diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 17877a447..39b6cc071 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -41,7 +41,6 @@ jq>=1.4.1,<2 jsonschema>1 keybert>=0.8.5 langchain_openai>=0.2.1 -litellm>=1.30,<=1.39.5 lxml>=4.9.3,<6.0 markdownify>=0.11.6,<0.12 motor>=3.3.1,<4 @@ -96,7 +95,6 @@ tree-sitter>=0.20.2,<0.21 tree-sitter-languages>=1.8.0,<2 upstash-redis>=1.1.0,<2 upstash-ratelimit>=1.1.0,<2 -vdms>=0.0.20 xata>=1.0.0a7,<2 xmltodict>=0.13.0,<0.14 nanopq==0.2.1 diff --git a/libs/community/langchain_community/__init__.py b/libs/community/langchain_community/__init__.py index 3368893c0..66ac16f73 100644 --- a/libs/community/langchain_community/__init__.py +++ b/libs/community/langchain_community/__init__.py @@ -1,4 +1,4 @@ -"""Main entrypoint into package.""" +"""Entrypoint into `langchain-community`.""" from importlib import metadata diff --git a/libs/community/langchain_community/adapters/__init__.py b/libs/community/langchain_community/adapters/__init__.py index 3834e1853..4404de26a 100644 --- a/libs/community/langchain_community/adapters/__init__.py +++ b/libs/community/langchain_community/adapters/__init__.py @@ -1,8 +1,8 @@ -"""**Adapters** are used to adapt LangChain models to other APIs. +"""Adapters are used to adapt LangChain models to other APIs. LangChain integrates with many model providers. -While LangChain has its own message and model APIs, -LangChain has also made it as easy as -possible to explore other models by exposing an **adapter** to adapt LangChain -models to the other APIs, as to the OpenAI API. + +While LangChain has its own message and model APIs, LangChain has also made it as easy +as possible to explore other models by exposing an **adapter** to adapt LangChain models +to the other APIs, such as to the OpenAI API. """ diff --git a/libs/community/langchain_community/adapters/openai.py b/libs/community/langchain_community/adapters/openai.py index a1b7b9c2c..3388dcc97 100644 --- a/libs/community/langchain_community/adapters/openai.py +++ b/libs/community/langchain_community/adapters/openai.py @@ -170,7 +170,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMess messages: List of dictionaries representing OpenAI messages Returns: - List of LangChain BaseMessage objects. + List of LangChain `BaseMessage` objects. 
""" return [convert_dict_to_message(m) for m in messages] diff --git a/libs/community/langchain_community/agent_toolkits/load_tools.py b/libs/community/langchain_community/agent_toolkits/load_tools.py index 510ee7fe6..e7f0c614b 100644 --- a/libs/community/langchain_community/agent_toolkits/load_tools.py +++ b/libs/community/langchain_community/agent_toolkits/load_tools.py @@ -35,10 +35,6 @@ from langchain_community.tools.google_jobs.tool import GoogleJobsQueryRun from langchain_community.tools.google_lens.tool import GoogleLensQueryRun from langchain_community.tools.google_scholar.tool import GoogleScholarQueryRun -from langchain_community.tools.google_search.tool import ( - GoogleSearchResults, - GoogleSearchRun, -) from langchain_community.tools.google_serper.tool import ( GoogleSerperResults, GoogleSerperRun, @@ -82,7 +78,6 @@ from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper from langchain_community.utilities.google_lens import GoogleLensAPIWrapper from langchain_community.utilities.google_scholar import GoogleScholarAPIWrapper -from langchain_community.utilities.google_search import GoogleSearchAPIWrapper from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper from langchain_community.utilities.graphql import GraphQLAPIWrapper @@ -310,10 +305,6 @@ def _get_wolfram_alpha(**kwargs: Any) -> BaseTool: return WolframAlphaQueryRun(api_wrapper=WolframAlphaAPIWrapper(**kwargs)) -def _get_google_search(**kwargs: Any) -> BaseTool: - return GoogleSearchRun(api_wrapper=GoogleSearchAPIWrapper(**kwargs)) - - def _get_merriam_webster(**kwargs: Any) -> BaseTool: return MerriamWebsterQueryRun(api_wrapper=MerriamWebsterAPIWrapper(**kwargs)) @@ -368,10 +359,6 @@ def _get_google_serper_results_json(**kwargs: Any) -> BaseTool: return GoogleSerperResults(api_wrapper=GoogleSerperAPIWrapper(**kwargs)) -def _get_google_search_results_json(**kwargs: Any) -> BaseTool: - return GoogleSearchResults(api_wrapper=GoogleSearchAPIWrapper(**kwargs)) - - def _get_searchapi(**kwargs: Any) -> BaseTool: return SearchAPIRun(api_wrapper=SearchApiAPIWrapper(**kwargs)) @@ -485,11 +472,6 @@ def _get_reddit_search(**kwargs: Any) -> BaseTool: } _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = { "wolfram-alpha": (_get_wolfram_alpha, ["wolfram_alpha_appid"]), - "google-search": (_get_google_search, ["google_api_key", "google_cse_id"]), - "google-search-results-json": ( - _get_google_search_results_json, - ["google_api_key", "google_cse_id", "num_results"], - ), "searx-search-results-json": ( _get_searx_search_results_json, ["searx_host", "engines", "num_results", "aiosession"], diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py index 0647776f1..80525d763 100644 --- a/libs/community/langchain_community/cache.py +++ b/libs/community/langchain_community/cache.py @@ -74,12 +74,6 @@ from langchain_core.outputs import ChatGeneration, Generation from langchain_core.utils import get_from_env -from langchain_community.utilities.astradb import ( - SetupMode as AstraSetupMode, -) -from langchain_community.utilities.astradb import ( - _AstraDBCollectionEnvironment, -) from langchain_community.vectorstores import ( AzureCosmosDBNoSqlVectorSearch, AzureCosmosDBVectorSearch, @@ -95,7 +89,6 @@ if TYPE_CHECKING: import momento import pymemcache - from astrapy.db import AstraDB, AsyncAstraDB from azure.cosmos.cosmos_client import 
CosmosClient from cassandra.cluster import Session as CassandraSession @@ -1611,178 +1604,6 @@ def get_md5(input_string: str) -> str: return hashlib.md5(input_string.encode()).hexdigest() -ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache" - - -@deprecated( - since="0.0.28", - removal="1.0", - alternative_import="langchain_astradb.AstraDBCache", -) -class AstraDBCache(BaseCache): - @staticmethod - def _make_id(prompt: str, llm_string: str) -> str: - return f"{_hash(prompt)}#{_hash(llm_string)}" - - def __init__( - self, - *, - collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - astra_db_client: Optional[AstraDB] = None, - async_astra_db_client: Optional[AsyncAstraDB] = None, - namespace: Optional[str] = None, - pre_delete_collection: bool = False, - setup_mode: AstraSetupMode = AstraSetupMode.SYNC, - ): - """ - Cache that uses Astra DB as a backend. - - It uses a single collection as a kv store - The lookup keys, combined in the _id of the documents, are: - - prompt, a string - - llm_string, a deterministic str representation of the model parameters. - (needed to prevent same-prompt-different-model collisions) - - Args: - collection_name: name of the Astra DB collection to create/use. - token: API token for Astra DB usage. - api_endpoint: full URL to the API endpoint, - such as `https://-us-east1.apps.astra.datastax.com`. - astra_db_client: *alternative to token+api_endpoint*, - you can pass an already-created 'astrapy.db.AstraDB' instance. - async_astra_db_client: *alternative to token+api_endpoint*, - you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. - namespace: namespace (aka keyspace) where the - collection is created. Defaults to the database's "default namespace". - setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or - OFF). - pre_delete_collection: whether to delete the collection - before creating it. If False and the collection already exists, - the collection will be used as is. 
- """ - self.astra_env = _AstraDBCollectionEnvironment( - collection_name=collection_name, - token=token, - api_endpoint=api_endpoint, - astra_db_client=astra_db_client, - async_astra_db_client=async_astra_db_client, - namespace=namespace, - setup_mode=setup_mode, - pre_delete_collection=pre_delete_collection, - ) - self.collection = self.astra_env.collection - self.async_collection = self.astra_env.async_collection - - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: - self.astra_env.ensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - item = self.collection.find_one( - filter={ - "_id": doc_id, - }, - projection={ - "body_blob": 1, - }, - )["data"]["document"] - return _loads_generations(item["body_blob"]) if item is not None else None - - async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: - await self.astra_env.aensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - item = ( - await self.async_collection.find_one( - filter={ - "_id": doc_id, - }, - projection={ - "body_blob": 1, - }, - ) - )["data"]["document"] - return _loads_generations(item["body_blob"]) if item is not None else None - - def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: - self.astra_env.ensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - blob = _dumps_generations(return_val) - self.collection.upsert( - { - "_id": doc_id, - "body_blob": blob, - }, - ) - - async def aupdate( - self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE - ) -> None: - await self.astra_env.aensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - blob = _dumps_generations(return_val) - await self.async_collection.upsert( - { - "_id": doc_id, - "body_blob": blob, - }, - ) - - def delete_through_llm( - self, prompt: str, llm: LLM, stop: Optional[List[str]] = None - ) -> None: - """ - A wrapper around `delete` with the LLM being passed. - In case the llm.invoke(prompt) calls have a `stop` param, you should - pass it here - """ - llm_string = get_prompts( - {**llm.dict(), **{"stop": stop}}, - [], - )[1] - return self.delete(prompt, llm_string=llm_string) - - async def adelete_through_llm( - self, prompt: str, llm: LLM, stop: Optional[List[str]] = None - ) -> None: - """ - A wrapper around `adelete` with the LLM being passed. 
- In case the llm.invoke(prompt) calls have a `stop` param, you should - pass it here - """ - llm_string = ( - await aget_prompts( - {**llm.dict(), **{"stop": stop}}, - [], - ) - )[1] - return await self.adelete(prompt, llm_string=llm_string) - - def delete(self, prompt: str, llm_string: str) -> None: - """Evict from cache if there's an entry.""" - self.astra_env.ensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - self.collection.delete_one(doc_id) - - async def adelete(self, prompt: str, llm_string: str) -> None: - """Evict from cache if there's an entry.""" - await self.astra_env.aensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - await self.async_collection.delete_one(doc_id) - - def clear(self, **kwargs: Any) -> None: - self.astra_env.ensure_db_setup() - self.collection.clear() - - async def aclear(self, **kwargs: Any) -> None: - await self.astra_env.aensure_db_setup() - await self.async_collection.clear() - - -ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85 -ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache" -ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 - - _unset = ["unset"] @@ -1819,269 +1640,6 @@ def decorating_function(user_function: Callable) -> Callable: return decorating_function -@deprecated( - since="0.0.28", - removal="1.0", - alternative_import="langchain_astradb.AstraDBSemanticCache", -) -class AstraDBSemanticCache(BaseCache): - def __init__( - self, - *, - collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - astra_db_client: Optional[AstraDB] = None, - async_astra_db_client: Optional[AsyncAstraDB] = None, - namespace: Optional[str] = None, - setup_mode: AstraSetupMode = AstraSetupMode.SYNC, - pre_delete_collection: bool = False, - embedding: Embeddings, - metric: Optional[str] = None, - similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, - ): - """ - Cache that uses Astra DB as a vector-store backend for semantic - (i.e. similarity-based) lookup. - - It uses a single (vector) collection and can store - cached values from several LLMs, so the LLM's 'llm_string' is stored - in the document metadata. - - You can choose the preferred similarity (or use the API default). - The default score threshold is tuned to the default metric. - Tune it carefully yourself if switching to another distance metric. - - Args: - collection_name: name of the Astra DB collection to create/use. - token: API token for Astra DB usage. - api_endpoint: full URL to the API endpoint, - such as `https://-us-east1.apps.astra.datastax.com`. - astra_db_client: *alternative to token+api_endpoint*, - you can pass an already-created 'astrapy.db.AstraDB' instance. - async_astra_db_client: *alternative to token+api_endpoint*, - you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. - namespace: namespace (aka keyspace) where the - collection is created. Defaults to the database's "default namespace". - setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or - OFF). - pre_delete_collection: whether to delete the collection - before creating it. If False and the collection already exists, - the collection will be used as is. - embedding: Embedding provider for semantic encoding and search. - metric: the function to use for evaluating similarity of text embeddings. - Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product') - similarity_threshold: the minimum similarity for accepting a - (semantic-search) match. 
- """ - self.embedding = embedding - self.metric = metric - self.similarity_threshold = similarity_threshold - self.collection_name = collection_name - - # The contract for this class has separate lookup and update: - # in order to spare some embedding calculations we cache them between - # the two calls. - # Note: each instance of this class has its own `_get_embedding` with - # its own lru. - @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE) - def _cache_embedding(text: str) -> List[float]: - return self.embedding.embed_query(text=text) - - self._get_embedding = _cache_embedding - - @_async_lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE) - async def _acache_embedding(text: str) -> List[float]: - return await self.embedding.aembed_query(text=text) - - self._aget_embedding = _acache_embedding - - embedding_dimension: Union[int, Awaitable[int], None] = None - if setup_mode == AstraSetupMode.ASYNC: - embedding_dimension = self._aget_embedding_dimension() - elif setup_mode == AstraSetupMode.SYNC: - embedding_dimension = self._get_embedding_dimension() - - self.astra_env = _AstraDBCollectionEnvironment( - collection_name=collection_name, - token=token, - api_endpoint=api_endpoint, - astra_db_client=astra_db_client, - async_astra_db_client=async_astra_db_client, - namespace=namespace, - setup_mode=setup_mode, - pre_delete_collection=pre_delete_collection, - embedding_dimension=embedding_dimension, - metric=metric, - ) - self.collection = self.astra_env.collection - self.async_collection = self.astra_env.async_collection - - def _get_embedding_dimension(self) -> int: - return len(self._get_embedding(text="This is a sample sentence.")) - - async def _aget_embedding_dimension(self) -> int: - return len(await self._aget_embedding(text="This is a sample sentence.")) - - @staticmethod - def _make_id(prompt: str, llm_string: str) -> str: - return f"{_hash(prompt)}#{_hash(llm_string)}" - - def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: - self.astra_env.ensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - llm_string_hash = _hash(llm_string) - embedding_vector = self._get_embedding(text=prompt) - body = _dumps_generations(return_val) - # - self.collection.upsert( - { - "_id": doc_id, - "body_blob": body, - "llm_string_hash": llm_string_hash, - "$vector": embedding_vector, - } - ) - - async def aupdate( - self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE - ) -> None: - await self.astra_env.aensure_db_setup() - doc_id = self._make_id(prompt, llm_string) - llm_string_hash = _hash(llm_string) - embedding_vector = await self._aget_embedding(text=prompt) - body = _dumps_generations(return_val) - # - await self.async_collection.upsert( - { - "_id": doc_id, - "body_blob": body, - "llm_string_hash": llm_string_hash, - "$vector": embedding_vector, - } - ) - - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: - hit_with_id = self.lookup_with_id(prompt, llm_string) - if hit_with_id is not None: - return hit_with_id[1] - else: - return None - - async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: - hit_with_id = await self.alookup_with_id(prompt, llm_string) - if hit_with_id is not None: - return hit_with_id[1] - else: - return None - - def lookup_with_id( - self, prompt: str, llm_string: str - ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: - """ - Look up based on prompt and llm_string. 
- If there are hits, return (document_id, cached_entry) for the top hit - """ - self.astra_env.ensure_db_setup() - prompt_embedding: List[float] = self._get_embedding(text=prompt) - llm_string_hash = _hash(llm_string) - - hit = self.collection.vector_find_one( - vector=prompt_embedding, - filter={ - "llm_string_hash": llm_string_hash, - }, - fields=["body_blob", "_id"], - include_similarity=True, - ) - - if hit is None or hit["$similarity"] < self.similarity_threshold: - return None - else: - generations = _loads_generations(hit["body_blob"]) - if generations is not None: - # this protects against malformed cached items: - return hit["_id"], generations - else: - return None - - async def alookup_with_id( - self, prompt: str, llm_string: str - ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: - """ - Look up based on prompt and llm_string. - If there are hits, return (document_id, cached_entry) for the top hit - """ - await self.astra_env.aensure_db_setup() - prompt_embedding: List[float] = await self._aget_embedding(text=prompt) - llm_string_hash = _hash(llm_string) - - hit = await self.async_collection.vector_find_one( - vector=prompt_embedding, - filter={ - "llm_string_hash": llm_string_hash, - }, - fields=["body_blob", "_id"], - include_similarity=True, - ) - - if hit is None or hit["$similarity"] < self.similarity_threshold: - return None - else: - generations = _loads_generations(hit["body_blob"]) - if generations is not None: - # this protects against malformed cached items: - return hit["_id"], generations - else: - return None - - def lookup_with_id_through_llm( - self, prompt: str, llm: LLM, stop: Optional[List[str]] = None - ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: - llm_string = get_prompts( - {**llm.dict(), **{"stop": stop}}, - [], - )[1] - return self.lookup_with_id(prompt, llm_string=llm_string) - - async def alookup_with_id_through_llm( - self, prompt: str, llm: LLM, stop: Optional[List[str]] = None - ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: - llm_string = ( - await aget_prompts( - {**llm.dict(), **{"stop": stop}}, - [], - ) - )[1] - return await self.alookup_with_id(prompt, llm_string=llm_string) - - def delete_by_document_id(self, document_id: str) -> None: - """ - Given this is a "similarity search" cache, an invalidation pattern - that makes sense is first a lookup to get an ID, and then deleting - with that ID. This is for the second step. - """ - self.astra_env.ensure_db_setup() - self.collection.delete_one(document_id) - - async def adelete_by_document_id(self, document_id: str) -> None: - """ - Given this is a "similarity search" cache, an invalidation pattern - that makes sense is first a lookup to get an ID, and then deleting - with that ID. This is for the second step. 
- """ - await self.astra_env.aensure_db_setup() - await self.async_collection.delete_one(document_id) - - def clear(self, **kwargs: Any) -> None: - self.astra_env.ensure_db_setup() - self.collection.clear() - - async def aclear(self, **kwargs: Any) -> None: - await self.astra_env.aensure_db_setup() - await self.async_collection.clear() - - class AzureCosmosDBSemanticCache(BaseCache): """Cache that uses Cosmos DB Mongo vCore vector-store backend""" diff --git a/libs/community/langchain_community/callbacks/context_callback.py b/libs/community/langchain_community/callbacks/context_callback.py index 3d4d2b6d1..362995828 100644 --- a/libs/community/langchain_community/callbacks/context_callback.py +++ b/libs/community/langchain_community/callbacks/context_callback.py @@ -41,7 +41,7 @@ class ContextCallbackHandler(BaseCallbackHandler): ImportError: if the `context-python` package is not installed. Chat Example: - >>> from langchain_community.llms import ChatOpenAI + >>> from langchain_openai import ChatOpenAI >>> from langchain_community.callbacks import ContextCallbackHandler >>> context_callback = ContextCallbackHandler( ... token="", @@ -60,7 +60,7 @@ class ContextCallbackHandler(BaseCallbackHandler): Chain Example: >>> from langchain_classic.chains import LLMChain - >>> from langchain_community.chat_models import ChatOpenAI + >>> from langchain_openai import ChatOpenAI >>> from langchain_community.callbacks import ContextCallbackHandler >>> context_callback = ContextCallbackHandler( ... token="", diff --git a/libs/community/langchain_community/callbacks/openai_info.py b/libs/community/langchain_community/callbacks/openai_info.py index aec6ba289..a4ccfce46 100644 --- a/libs/community/langchain_community/callbacks/openai_info.py +++ b/libs/community/langchain_community/callbacks/openai_info.py @@ -4,7 +4,6 @@ from enum import Enum, auto from typing import Any, Dict, List -from langchain_core._api import warn_deprecated from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import AIMessage from langchain_core.outputs import ChatGeneration, LLMResult @@ -315,7 +314,6 @@ class TokenType(Enum): def standardize_model_name( model_name: str, - is_completion: bool = False, *, token_type: TokenType = TokenType.PROMPT, ) -> str: @@ -324,25 +322,12 @@ def standardize_model_name( Args: model_name: Model name to standardize. - is_completion: Whether the model is used for completion or not. - Defaults to False. Deprecated in favor of ``token_type``. token_type: Token type. Defaults to ``TokenType.PROMPT``. Returns: Standardized model name. """ - if is_completion: - warn_deprecated( - since="0.3.13", - message=( - "is_completion is deprecated. Use token_type instead. Example:\n\n" - "from langchain_community.callbacks.openai_info import TokenType\n\n" - "standardize_model_name('gpt-4o', token_type=TokenType.COMPLETION)\n" - ), - removal="1.0", - ) - token_type = TokenType.COMPLETION model_name = model_name.lower() if ".ft-" in model_name: model_name = model_name.split(".ft-")[0] + "-azure-finetuned" @@ -381,7 +366,6 @@ def standardize_model_name( def get_openai_token_cost_for_model( model_name: str, num_tokens: int, - is_completion: bool = False, *, token_type: TokenType = TokenType.PROMPT, ) -> float: @@ -391,24 +375,11 @@ def get_openai_token_cost_for_model( Args: model_name: Name of the model num_tokens: Number of tokens. - is_completion: Whether the model is used for completion or not. - Defaults to False. Deprecated in favor of ``token_type``. token_type: Token type. 
Defaults to ``TokenType.PROMPT``. Returns: Cost in USD. """ - if is_completion: - warn_deprecated( - since="0.3.13", - message=( - "is_completion is deprecated. Use token_type instead. Example:\n\n" - "from langchain_community.callbacks.openai_info import TokenType\n\n" - "get_openai_token_cost_for_model('gpt-4o', 10, token_type=TokenType.COMPLETION)\n" # noqa: E501 - ), - removal="1.0", - ) - token_type = TokenType.COMPLETION model_name = standardize_model_name(model_name, token_type=token_type) if model_name not in MODEL_COST_PER_1K_TOKENS: raise ValueError( diff --git a/libs/community/langchain_community/chains/graph_qa/neptune_cypher.py b/libs/community/langchain_community/chains/graph_qa/neptune_cypher.py deleted file mode 100644 index 7318962b6..000000000 --- a/libs/community/langchain_community/chains/graph_qa/neptune_cypher.py +++ /dev/null @@ -1,254 +0,0 @@ -from __future__ import annotations - -import re -from typing import Any, Dict, List, Optional - -from langchain_classic.chains.base import Chain -from langchain_classic.chains.llm import LLMChain -from langchain_classic.chains.prompt_selector import ConditionalPromptSelector -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts.base import BasePromptTemplate -from pydantic import Field - -from langchain_community.chains.graph_qa.prompts import ( - CYPHER_QA_PROMPT, - NEPTUNE_OPENCYPHER_GENERATION_PROMPT, - NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT, -) -from langchain_community.graphs import BaseNeptuneGraph - -INTERMEDIATE_STEPS_KEY = "intermediate_steps" - - -def trim_query(query: str) -> str: - """Trim the query to only include Cypher keywords.""" - keywords = ( - "CALL", - "CREATE", - "DELETE", - "DETACH", - "LIMIT", - "MATCH", - "MERGE", - "OPTIONAL", - "ORDER", - "REMOVE", - "RETURN", - "SET", - "SKIP", - "UNWIND", - "WITH", - "WHERE", - "//", - ) - - lines = query.split("\n") - new_query = "" - - for line in lines: - if line.strip().upper().startswith(keywords): - new_query += line + "\n" - - return new_query - - -def extract_cypher(text: str) -> str: - """Extract Cypher code from text using Regex.""" - # The pattern to find Cypher code enclosed in triple backticks - pattern = r"```(.*?)```" - - # Find all matches in the input text - matches = re.findall(pattern, text, re.DOTALL) - - return matches[0] if matches else text - - -def use_simple_prompt(llm: BaseLanguageModel) -> bool: - """Decides whether to use the simple prompt""" - if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore[attr-defined] - return True - - # Bedrock anthropic - if hasattr(llm, "model_id") and "anthropic" in llm.model_id: - return True - - return False - - -PROMPT_SELECTOR = ConditionalPromptSelector( - default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT, - conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)], -) - - -@deprecated( - since="0.3.15", - removal="1.0", - alternative_import="langchain_aws.create_neptune_opencypher_qa_chain", -) -class NeptuneOpenCypherQAChain(Chain): - """Chain for question-answering against a Neptune graph - by generating openCypher statements. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. 
- Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. - - Example: - .. code-block:: python - - chain = NeptuneOpenCypherQAChain.from_llm( - llm=llm, - graph=graph - ) - response = chain.run(query) - """ - - graph: BaseNeptuneGraph = Field(exclude=True) - cypher_generation_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - top_k: int = 10 - return_intermediate_steps: bool = False - """Whether or not to return the intermediate steps along with the final answer.""" - return_direct: bool = False - """Whether or not to return the result of querying the graph directly.""" - extra_instructions: Optional[str] = None - """Extra instructions by the appended to the query generation prompt.""" - - allow_dangerous_requests: bool = False - """Forced user opt-in to acknowledge that the chain can make dangerous requests. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. - """ - - def __init__(self, **kwargs: Any) -> None: - """Initialize the chain.""" - super().__init__(**kwargs) - if self.allow_dangerous_requests is not True: - raise ValueError( - "In order to use this chain, you must acknowledge that it can make " - "dangerous requests by setting `allow_dangerous_requests` to `True`." - "You must narrowly scope the permissions of the database connection " - "to only include necessary permissions. Failure to do so may result " - "in data corruption or loss or reading sensitive data if such data is " - "present in the database." - "Only use this chain if you understand the risks and have taken the " - "necessary precautions. " - "See https://python.langchain.com/docs/security for more information." - ) - - @property - def input_keys(self) -> List[str]: - """Return the input keys. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return the output keys. 
- - :meta private: - """ - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - cypher_prompt: Optional[BasePromptTemplate] = None, - extra_instructions: Optional[str] = None, - **kwargs: Any, - ) -> NeptuneOpenCypherQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - - _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm) - cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt) - - return cls( - qa_chain=qa_chain, - cypher_generation_chain=cypher_generation_chain, - extra_instructions=extra_instructions, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - """Generate Cypher statement, use it to look up in db and answer question.""" - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - question = inputs[self.input_key] - - intermediate_steps: List = [] - - generated_cypher = self.cypher_generation_chain.run( - { - "question": question, - "schema": self.graph.get_schema, - "extra_instructions": self.extra_instructions or "", - }, - callbacks=callbacks, - ) - - # Extract Cypher code if it is wrapped in backticks - generated_cypher = extract_cypher(generated_cypher) - generated_cypher = trim_query(generated_cypher) - - _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_cypher, color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"query": generated_cypher}) - - context = self.graph.query(generated_cypher) - - if self.return_direct: - final_result = context - else: - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"context": context}) - - result = self.qa_chain( - {"question": question, "context": context}, - callbacks=callbacks, - ) - final_result = result[self.qa_chain.output_key] - - chain_result: Dict[str, Any] = {self.output_key: final_result} - if self.return_intermediate_steps: - chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps - - return chain_result diff --git a/libs/community/langchain_community/chains/graph_qa/neptune_sparql.py b/libs/community/langchain_community/chains/graph_qa/neptune_sparql.py deleted file mode 100644 index 60a35eab2..000000000 --- a/libs/community/langchain_community/chains/graph_qa/neptune_sparql.py +++ /dev/null @@ -1,242 +0,0 @@ -""" -Question answering over an RDF or OWL graph using SPARQL. 
-""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from langchain_classic.chains.base import Chain -from langchain_classic.chains.llm import LLMChain -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks.manager import CallbackManagerForChainRun -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.prompts.prompt import PromptTemplate -from pydantic import Field - -from langchain_community.chains.graph_qa.prompts import SPARQL_QA_PROMPT -from langchain_community.graphs import NeptuneRdfGraph - -INTERMEDIATE_STEPS_KEY = "intermediate_steps" - -SPARQL_GENERATION_TEMPLATE = """ -Task: Generate a SPARQL SELECT statement for querying a graph database. -For instance, to find all email addresses of John Doe, the following -query in backticks would be suitable: -``` -PREFIX foaf: -SELECT ?email -WHERE {{ - ?person foaf:name "John Doe" . - ?person foaf:mbox ?email . -}} -``` -Instructions: -Use only the node types and properties provided in the schema. -Do not use any node types and properties that are not explicitly provided. -Include all necessary prefixes. - -Examples: - -Schema: -{schema} -Note: Be as concise as possible. -Do not include any explanations or apologies in your responses. -Do not respond to any questions that ask for anything else than -for you to construct a SPARQL query. -Do not include any text except the SPARQL query generated. - -The question is: -{prompt}""" - -SPARQL_GENERATION_PROMPT = PromptTemplate( - input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE -) - - -def extract_sparql(query: str) -> str: - """Extract SPARQL code from a text. - - Args: - query: Text to extract SPARQL code from. - - Returns: - SPARQL code extracted from the text. - """ - query = query.strip() - querytoks = query.split("```") - if len(querytoks) == 3: - query = querytoks[1] - - if query.startswith("sparql"): - query = query[6:] - elif query.startswith("") and query.endswith(""): - query = query[8:-9] - return query - - -@deprecated( - since="0.3.15", - removal="1.0", - alternative_import="langchain_aws.create_neptune_sparql_qa_chain", -) -class NeptuneSparqlQAChain(Chain): - """Chain for question-answering against a Neptune graph - by generating SPARQL statements. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. - - Example: - .. 
code-block:: python - - chain = NeptuneSparqlQAChain.from_llm( - llm=llm, - graph=graph - ) - response = chain.invoke(query) - """ - - graph: NeptuneRdfGraph = Field(exclude=True) - sparql_generation_chain: LLMChain - qa_chain: LLMChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - top_k: int = 10 - return_intermediate_steps: bool = False - """Whether or not to return the intermediate steps along with the final answer.""" - return_direct: bool = False - """Whether or not to return the result of querying the graph directly.""" - extra_instructions: Optional[str] = None - """Extra instructions by the appended to the query generation prompt.""" - - allow_dangerous_requests: bool = False - """Forced user opt-in to acknowledge that the chain can make dangerous requests. - - *Security note*: Make sure that the database connection uses credentials - that are narrowly-scoped to only include necessary permissions. - Failure to do so may result in data corruption or loss, since the calling - code may attempt commands that would result in deletion, mutation - of data if appropriately prompted or reading sensitive data if such - data is present in the database. - The best way to guard against such negative outcomes is to (as appropriate) - limit the permissions granted to the credentials used with this tool. - - See https://python.langchain.com/docs/security for more information. - """ - - def __init__(self, **kwargs: Any) -> None: - """Initialize the chain.""" - super().__init__(**kwargs) - if self.allow_dangerous_requests is not True: - raise ValueError( - "In order to use this chain, you must acknowledge that it can make " - "dangerous requests by setting `allow_dangerous_requests` to `True`." - "You must narrowly scope the permissions of the database connection " - "to only include necessary permissions. Failure to do so may result " - "in data corruption or loss or reading sensitive data if such data is " - "present in the database." - "Only use this chain if you understand the risks and have taken the " - "necessary precautions. " - "See https://python.langchain.com/docs/security for more information." - ) - - @property - def input_keys(self) -> List[str]: - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - _output_keys = [self.output_key] - return _output_keys - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - *, - qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, - sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT, - examples: Optional[str] = None, - **kwargs: Any, - ) -> NeptuneSparqlQAChain: - """Initialize from LLM.""" - qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - template_to_use = SPARQL_GENERATION_TEMPLATE - if examples: - template_to_use = template_to_use.replace( - "Examples:", "Examples: " + examples - ) - sparql_prompt = PromptTemplate( - input_variables=["schema", "prompt"], template=template_to_use - ) - sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt) - - return cls( - qa_chain=qa_chain, - sparql_generation_chain=sparql_generation_chain, - examples=examples, - **kwargs, - ) - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - """ - Generate SPARQL query, use it to retrieve a response from the gdb and answer - the question. 
- """ - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - prompt = inputs[self.input_key] - - intermediate_steps: List = [] - - generated_sparql = self.sparql_generation_chain.run( - {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks - ) - - # Extract SPARQL - generated_sparql = extract_sparql(generated_sparql) - - _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) - _run_manager.on_text( - generated_sparql, color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"query": generated_sparql}) - - context = self.graph.query(generated_sparql) - - if self.return_direct: - final_result = context - else: - _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - _run_manager.on_text( - str(context), color="green", end="\n", verbose=self.verbose - ) - - intermediate_steps.append({"context": context}) - - result = self.qa_chain( - {"prompt": prompt, "context": context}, - callbacks=callbacks, - ) - final_result = result[self.qa_chain.output_key] - - chain_result: Dict[str, Any] = {self.output_key: final_result} - if self.return_intermediate_steps: - chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps - - return chain_result diff --git a/libs/community/langchain_community/chat_loaders/__init__.py b/libs/community/langchain_community/chat_loaders/__init__.py index 255b9b159..026e87c2b 100644 --- a/libs/community/langchain_community/chat_loaders/__init__.py +++ b/libs/community/langchain_community/chat_loaders/__init__.py @@ -29,9 +29,6 @@ FolderFacebookMessengerChatLoader, SingleFileFacebookMessengerChatLoader, ) - from langchain_community.chat_loaders.gmail import ( - GMailLoader, - ) from langchain_community.chat_loaders.imessage import ( IMessageChatLoader, ) @@ -52,7 +49,6 @@ __all__ = [ "BaseChatLoader", "FolderFacebookMessengerChatLoader", - "GMailLoader", "IMessageChatLoader", "LangSmithDatasetChatLoader", "LangSmithRunChatLoader", @@ -65,7 +61,6 @@ _module_lookup = { "BaseChatLoader": "langchain_core.chat_loaders", "FolderFacebookMessengerChatLoader": "langchain_community.chat_loaders.facebook_messenger", # noqa: E501 - "GMailLoader": "langchain_community.chat_loaders.gmail", "IMessageChatLoader": "langchain_community.chat_loaders.imessage", "LangSmithDatasetChatLoader": "langchain_community.chat_loaders.langsmith", "LangSmithRunChatLoader": "langchain_community.chat_loaders.langsmith", diff --git a/libs/community/langchain_community/chat_loaders/gmail.py b/libs/community/langchain_community/chat_loaders/gmail.py deleted file mode 100644 index 03dff9c24..000000000 --- a/libs/community/langchain_community/chat_loaders/gmail.py +++ /dev/null @@ -1,117 +0,0 @@ -import base64 -import re -from typing import Any, Iterator - -from langchain_core._api.deprecation import deprecated -from langchain_core.chat_loaders import BaseChatLoader -from langchain_core.chat_sessions import ChatSession -from langchain_core.messages import HumanMessage - - -def _extract_email_content(msg: Any) -> HumanMessage: - from_email = None - for values in msg["payload"]["headers"]: - name = values["name"] - if name == "From": - from_email = values["value"] - if from_email is None: - raise ValueError - for part in msg["payload"]["parts"]: - if part["mimeType"] == "text/plain": - data = part["body"]["data"] - data = base64.urlsafe_b64decode(data).decode("utf-8") - # Regular expression to split the email body at the first - # occurrence of a line that starts with "On ... 
wrote:" - pattern = re.compile(r"\r\nOn .+(\r\n)*wrote:\r\n") - # Split the email body and extract the first part - newest_response = re.split(pattern, data)[0] - message = HumanMessage( - content=newest_response, additional_kwargs={"sender": from_email} - ) - return message - raise ValueError - - -def _get_message_data(service: Any, message: Any) -> ChatSession: - msg = service.users().messages().get(userId="me", id=message["id"]).execute() - message_content = _extract_email_content(msg) - in_reply_to = None - email_data = msg["payload"]["headers"] - for values in email_data: - name = values["name"] - if name == "In-Reply-To": - in_reply_to = values["value"] - if in_reply_to is None: - raise ValueError - - thread_id = msg["threadId"] - - thread = service.users().threads().get(userId="me", id=thread_id).execute() - messages = thread["messages"] - - response_email = None - for message in messages: - email_data = message["payload"]["headers"] - for values in email_data: - if values["name"] == "Message-ID": - message_id = values["value"] - if message_id == in_reply_to: - response_email = message - if response_email is None: - raise ValueError - starter_content = _extract_email_content(response_email) - return ChatSession(messages=[starter_content, message_content]) - - -@deprecated( - since="0.0.32", - removal="1.0", - alternative_import="langchain_google_community.GMailLoader", -) -class GMailLoader(BaseChatLoader): - """Load data from `GMail`. - - There are many ways you could want to load data from GMail. - This loader is currently fairly opinionated in how to do so. - The way it does it is it first looks for all messages that you have sent. - It then looks for messages where you are responding to a previous email. - It then fetches that previous email, and creates a training example - of that email, followed by your email. - - Note that there are clear limitations here. For example, - all examples created are only looking at the previous email for context. - - To use: - - - Set up a Google Developer Account: - Go to the Google Developer Console, create a project, - and enable the Gmail API for that project. - This will give you a credentials.json file that you'll need later. 
- """ - - def __init__(self, creds: Any, n: int = 100, raise_error: bool = False) -> None: - super().__init__() - self.creds = creds - self.n = n - self.raise_error = raise_error - - def lazy_load(self) -> Iterator[ChatSession]: - from googleapiclient.discovery import build - - service = build("gmail", "v1", credentials=self.creds) - results = ( - service.users() - .messages() - .list(userId="me", labelIds=["SENT"], maxResults=self.n) - .execute() - ) - messages = results.get("messages", []) - for message in messages: - try: - yield _get_message_data(service, message) - except Exception as e: - # TODO: handle errors better - if self.raise_error: - raise e - else: - pass diff --git a/libs/community/langchain_community/chat_message_histories/__init__.py b/libs/community/langchain_community/chat_message_histories/__init__.py index fc20cacac..0a1932fd3 100644 --- a/libs/community/langchain_community/chat_message_histories/__init__.py +++ b/libs/community/langchain_community/chat_message_histories/__init__.py @@ -19,9 +19,6 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from langchain_community.chat_message_histories.astradb import ( - AstraDBChatMessageHistory, - ) from langchain_community.chat_message_histories.cassandra import ( CassandraChatMessageHistory, ) @@ -49,12 +46,6 @@ from langchain_community.chat_message_histories.momento import ( MomentoChatMessageHistory, ) - from langchain_community.chat_message_histories.mongodb import ( - MongoDBChatMessageHistory, - ) - from langchain_community.chat_message_histories.neo4j import ( - Neo4jChatMessageHistory, - ) from langchain_community.chat_message_histories.postgres import ( PostgresChatMessageHistory, ) @@ -90,7 +81,6 @@ ) __all__ = [ - "AstraDBChatMessageHistory", "CassandraChatMessageHistory", "ChatMessageHistory", "CosmosDBChatMessageHistory", @@ -99,8 +89,6 @@ "FileChatMessageHistory", "FirestoreChatMessageHistory", "MomentoChatMessageHistory", - "MongoDBChatMessageHistory", - "Neo4jChatMessageHistory", "PostgresChatMessageHistory", "RedisChatMessageHistory", "RocksetChatMessageHistory", @@ -116,7 +104,6 @@ ] _module_lookup = { - "AstraDBChatMessageHistory": "langchain_community.chat_message_histories.astradb", "CassandraChatMessageHistory": "langchain_community.chat_message_histories.cassandra", # noqa: E501 "ChatMessageHistory": "langchain_community.chat_message_histories.in_memory", "CosmosDBChatMessageHistory": "langchain_community.chat_message_histories.cosmos_db", # noqa: E501 @@ -125,8 +112,6 @@ "FileChatMessageHistory": "langchain_community.chat_message_histories.file", "FirestoreChatMessageHistory": "langchain_community.chat_message_histories.firestore", # noqa: E501 "MomentoChatMessageHistory": "langchain_community.chat_message_histories.momento", - "MongoDBChatMessageHistory": "langchain_community.chat_message_histories.mongodb", - "Neo4jChatMessageHistory": "langchain_community.chat_message_histories.neo4j", "PostgresChatMessageHistory": "langchain_community.chat_message_histories.postgres", "RedisChatMessageHistory": "langchain_community.chat_message_histories.redis", "RocksetChatMessageHistory": "langchain_community.chat_message_histories.rocksetdb", diff --git a/libs/community/langchain_community/chat_message_histories/astradb.py b/libs/community/langchain_community/chat_message_histories/astradb.py deleted file mode 100644 index f64339356..000000000 --- a/libs/community/langchain_community/chat_message_histories/astradb.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Astra DB - based chat message history, based on 
astrapy.""" - -from __future__ import annotations - -import json -import time -from typing import TYPE_CHECKING, List, Optional, Sequence - -from langchain_community.utilities.astradb import ( - SetupMode, - _AstraDBCollectionEnvironment, -) - -if TYPE_CHECKING: - from astrapy.db import AstraDB, AsyncAstraDB - -from langchain_core._api.deprecation import deprecated -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import ( - BaseMessage, - message_to_dict, - messages_from_dict, -) - -DEFAULT_COLLECTION_NAME = "langchain_message_store" - - -@deprecated( - since="0.0.25", - removal="1.0", - alternative_import="langchain_astradb.AstraDBChatMessageHistory", -) -class AstraDBChatMessageHistory(BaseChatMessageHistory): - def __init__( - self, - *, - session_id: str, - collection_name: str = DEFAULT_COLLECTION_NAME, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - astra_db_client: Optional[AstraDB] = None, - async_astra_db_client: Optional[AsyncAstraDB] = None, - namespace: Optional[str] = None, - setup_mode: SetupMode = SetupMode.SYNC, - pre_delete_collection: bool = False, - ) -> None: - """Chat message history that stores history in Astra DB. - - Args: - session_id: arbitrary key that is used to store the messages - of a single chat session. - collection_name: name of the Astra DB collection to create/use. - token: API token for Astra DB usage. - api_endpoint: full URL to the API endpoint, - such as "https://-us-east1.apps.astra.datastax.com". - astra_db_client: *alternative to token+api_endpoint*, - you can pass an already-created 'astrapy.db.AstraDB' instance. - async_astra_db_client: *alternative to token+api_endpoint*, - you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. - namespace: namespace (aka keyspace) where the - collection is created. Defaults to the database's "default namespace". - setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or - OFF). - pre_delete_collection: whether to delete the collection - before creating it. If False and the collection already exists, - the collection will be used as is. 
- """ - self.astra_env = _AstraDBCollectionEnvironment( - collection_name=collection_name, - token=token, - api_endpoint=api_endpoint, - astra_db_client=astra_db_client, - async_astra_db_client=async_astra_db_client, - namespace=namespace, - setup_mode=setup_mode, - pre_delete_collection=pre_delete_collection, - ) - - self.collection = self.astra_env.collection - self.async_collection = self.astra_env.async_collection - - self.session_id = session_id - self.collection_name = collection_name - - @property - def messages(self) -> List[BaseMessage]: - """Retrieve all session messages from DB""" - self.astra_env.ensure_db_setup() - message_blobs = [ - doc["body_blob"] - for doc in sorted( - self.collection.paginated_find( - filter={ - "session_id": self.session_id, - }, - projection={ - "timestamp": 1, - "body_blob": 1, - }, - ), - key=lambda _doc: _doc["timestamp"], - ) - ] - items = [json.loads(message_blob) for message_blob in message_blobs] - messages = messages_from_dict(items) - return messages - - @messages.setter - def messages(self, messages: List[BaseMessage]) -> None: - raise NotImplementedError("Use add_messages instead") - - async def aget_messages(self) -> List[BaseMessage]: - await self.astra_env.aensure_db_setup() - docs = self.async_collection.paginated_find( - filter={ - "session_id": self.session_id, - }, - projection={ - "timestamp": 1, - "body_blob": 1, - }, - ) - sorted_docs = sorted( - [doc async for doc in docs], - key=lambda _doc: _doc["timestamp"], - ) - message_blobs = [doc["body_blob"] for doc in sorted_docs] - items = [json.loads(message_blob) for message_blob in message_blobs] - messages = messages_from_dict(items) - return messages - - def add_messages(self, messages: Sequence[BaseMessage]) -> None: - self.astra_env.ensure_db_setup() - docs = [ - { - "timestamp": time.time(), - "session_id": self.session_id, - "body_blob": json.dumps(message_to_dict(message)), - } - for message in messages - ] - self.collection.chunked_insert_many(docs) - - async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: - await self.astra_env.aensure_db_setup() - docs = [ - { - "timestamp": time.time(), - "session_id": self.session_id, - "body_blob": json.dumps(message_to_dict(message)), - } - for message in messages - ] - await self.async_collection.chunked_insert_many(docs) - - def clear(self) -> None: - self.astra_env.ensure_db_setup() - self.collection.delete_many(filter={"session_id": self.session_id}) - - async def aclear(self) -> None: - await self.astra_env.aensure_db_setup() - await self.async_collection.delete_many(filter={"session_id": self.session_id}) diff --git a/libs/community/langchain_community/chat_message_histories/mongodb.py b/libs/community/langchain_community/chat_message_histories/mongodb.py deleted file mode 100644 index 4ab688765..000000000 --- a/libs/community/langchain_community/chat_message_histories/mongodb.py +++ /dev/null @@ -1,101 +0,0 @@ -import json -import logging -from typing import List - -from langchain_core._api.deprecation import deprecated -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import ( - BaseMessage, - message_to_dict, - messages_from_dict, -) - -logger = logging.getLogger(__name__) - -DEFAULT_DBNAME = "chat_history" -DEFAULT_COLLECTION_NAME = "message_store" - - -@deprecated( - since="0.0.25", - removal="1.0", - alternative_import="langchain_mongodb.MongoDBChatMessageHistory", -) -class MongoDBChatMessageHistory(BaseChatMessageHistory): - """Chat message history that stores 
history in MongoDB. - - Args: - connection_string: connection string to connect to MongoDB - session_id: arbitrary key that is used to store the messages - of a single chat session. - database_name: name of the database to use - collection_name: name of the collection to use - create_index: whether to create an index with name SessionId. Set to False if - such an index already exists. - """ - - def __init__( - self, - connection_string: str, - session_id: str, - database_name: str = DEFAULT_DBNAME, - collection_name: str = DEFAULT_COLLECTION_NAME, - create_index: bool = True, - ): - from pymongo import MongoClient, errors - - self.connection_string = connection_string - self.session_id = session_id - self.database_name = database_name - self.collection_name = collection_name - - try: - self.client: MongoClient = MongoClient(connection_string) - except errors.ConnectionFailure as error: - logger.error(error) - - self.db = self.client[database_name] - self.collection = self.db[collection_name] - if create_index: - self.collection.create_index("SessionId") - - @property - def messages(self) -> List[BaseMessage]: # type: ignore[override] - """Retrieve the messages from MongoDB""" - from pymongo import errors - - try: - cursor = self.collection.find({"SessionId": self.session_id}) - except errors.OperationFailure as error: - logger.error(error) - - if cursor: - items = [json.loads(document["History"]) for document in cursor] - else: - items = [] - - messages = messages_from_dict(items) - return messages - - def add_message(self, message: BaseMessage) -> None: - """Append the message to the record in MongoDB""" - from pymongo import errors - - try: - self.collection.insert_one( - { - "SessionId": self.session_id, - "History": json.dumps(message_to_dict(message)), - } - ) - except errors.WriteError as err: - logger.error(err) - - def clear(self) -> None: - """Clear session memory from MongoDB""" - from pymongo import errors - - try: - self.collection.delete_many({"SessionId": self.session_id}) - except errors.WriteError as err: - logger.error(err) diff --git a/libs/community/langchain_community/chat_message_histories/neo4j.py b/libs/community/langchain_community/chat_message_histories/neo4j.py deleted file mode 100644 index 9d2cb3178..000000000 --- a/libs/community/langchain_community/chat_message_histories/neo4j.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import List, Optional, Union - -from langchain_core._api.deprecation import deprecated -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import BaseMessage, messages_from_dict -from langchain_core.utils import get_from_dict_or_env - -from langchain_community.graphs import Neo4jGraph - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.Neo4jChatMessageHistory", -) -class Neo4jChatMessageHistory(BaseChatMessageHistory): - """Chat message history stored in a Neo4j database.""" - - def __init__( - self, - session_id: Union[str, int], - url: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - database: str = "neo4j", - node_label: str = "Session", - window: int = 3, - *, - graph: Optional[Neo4jGraph] = None, - ): - try: - import neo4j - except ImportError: - raise ImportError( - "Could not import neo4j python package. " - "Please install it with `pip install neo4j`." 
- ) - - # Make sure session id is not null - if not session_id: - raise ValueError("Please ensure that the session_id parameter is provided") - - # Graph object takes precedent over env or input params - if graph: - self._driver = graph._driver - self._database = graph._database - else: - # Handle if the credentials are environment variables - url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI") - username = get_from_dict_or_env( - {"username": username}, "username", "NEO4J_USERNAME" - ) - password = get_from_dict_or_env( - {"password": password}, "password", "NEO4J_PASSWORD" - ) - database = get_from_dict_or_env( - {"database": database}, "database", "NEO4J_DATABASE", "neo4j" - ) - - self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) - self._database = database - # Verify connection - try: - self._driver.verify_connectivity() - except neo4j.exceptions.ServiceUnavailable: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the url is correct" - ) - except neo4j.exceptions.AuthError: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the username and password are correct" - ) - self._session_id = session_id - self._node_label = node_label - self._window = window - # Create session node - self._driver.execute_query( - f"MERGE (s:`{self._node_label}` {{id:$session_id}})", - {"session_id": self._session_id}, - ).summary - - @property - def messages(self) -> List[BaseMessage]: - """Retrieve the messages from Neo4j""" - query = ( - f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) " - "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.." - f"{self._window * 2}]-() WITH p, length(p) AS length " - "ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node " - "RETURN {data:{content: node.content}, type:node.type} AS result" - ) - records, _, _ = self._driver.execute_query( - query, {"session_id": self._session_id} - ) - - messages = messages_from_dict([el["result"] for el in records]) - return messages - - @messages.setter - def messages(self, messages: List[BaseMessage]) -> None: - raise NotImplementedError( - "Direct assignment to 'messages' is not allowed." - " Use the 'add_messages' instead." 
- ) - - def add_message(self, message: BaseMessage) -> None: - """Append the message to the record in Neo4j""" - query = ( - f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id " - "OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) " - "CREATE (s)-[:LAST_MESSAGE]->(new:Message) " - "SET new += {type:$type, content:$content} " - "WITH new, lm, last_message WHERE last_message IS NOT NULL " - "CREATE (last_message)-[:NEXT]->(new) " - "DELETE lm" - ) - self._driver.execute_query( - query, - { - "type": message.type, - "content": message.content, - "session_id": self._session_id, - }, - ).summary - - def clear(self) -> None: - """Clear session memory from Neo4j""" - query = ( - f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) " - "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() " - "WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 " - "UNWIND nodes(p) as node DETACH DELETE node;" - ) - self._driver.execute_query(query, {"session_id": self._session_id}).summary - - def __del__(self) -> None: - if self._driver: - self._driver.close() diff --git a/libs/community/langchain_community/chat_message_histories/sql.py b/libs/community/langchain_community/chat_message_histories/sql.py index d26804ec7..a3e86c08b 100644 --- a/libs/community/langchain_community/chat_message_histories/sql.py +++ b/libs/community/langchain_community/chat_message_histories/sql.py @@ -14,7 +14,6 @@ cast, ) -from langchain_core._api import deprecated, warn_deprecated from sqlalchemy import Column, Integer, Text, delete, select try: @@ -146,15 +145,9 @@ class SQLChatMessageHistory(BaseChatMessageHistory): """ - @property - @deprecated("0.2.2", removal="1.0", alternative="session_maker") - def Session(self) -> Union[scoped_session, async_sessionmaker]: - return self.session_maker - def __init__( self, session_id: str, - connection_string: Optional[str] = None, table_name: str = "message_store", session_id_field_name: str = "session_id", custom_message_converter: Optional[BaseMessageConverter] = None, @@ -166,8 +159,6 @@ def __init__( Args: session_id: Indicates the id of the same session. - connection_string: String parameter configuration for connecting - to the database. table_name: Table name used to save data. session_id_field_name: The name of field of `session_id`. custom_message_converter: Custom message converter for converting @@ -177,21 +168,6 @@ def __init__( engine_args: Additional configuration for creating database engines. async_mode: Whether it is an asynchronous connection. """ - assert not (connection_string and connection), ( - "connection_string and connection are mutually exclusive" - ) - if connection_string: - global _warned_once_already - if not _warned_once_already: - warn_deprecated( - since="0.2.2", - removal="1.0", - name="connection_string", - alternative="connection", - ) - _warned_once_already = True - connection = connection_string - self.connection_string = connection_string if isinstance(connection, str): self.async_mode = async_mode if async_mode: diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 9c83bdecb..09f24fe11 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -1,53 +1,26 @@ """**Chat Models** are a variation on language models. -While Chat Models use language models under the hood, the interface they expose -is a bit different. 
Rather than expose a "text in, text out" API, they expose -an interface where "chat messages" are the inputs and outputs. - -**Class hierarchy:** - -.. code-block:: - - BaseLanguageModel --> BaseChatModel --> # Examples: ChatOpenAI, ChatGooglePalm - -**Main helpers:** - -.. code-block:: - - AIMessage, BaseMessage, HumanMessage -""" # noqa: E501 +While Chat Models use language models under the hood, the interface they expose is a bit +different. Rather than expose a "text in, text out" API, they expose an interface where +"chat messages" are the inputs and outputs. +""" import importlib from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from langchain_community.chat_models.anthropic import ( - ChatAnthropic, - ) from langchain_community.chat_models.anyscale import ( ChatAnyscale, ) - from langchain_community.chat_models.azure_openai import ( - AzureChatOpenAI, - ) from langchain_community.chat_models.baichuan import ( ChatBaichuan, ) from langchain_community.chat_models.baidu_qianfan_endpoint import ( QianfanChatEndpoint, ) - from langchain_community.chat_models.bedrock import ( - BedrockChat, - ) - from langchain_community.chat_models.cohere import ( - ChatCohere, - ) from langchain_community.chat_models.coze import ( ChatCoze, ) - from langchain_community.chat_models.databricks import ( - ChatDatabricks, - ) from langchain_community.chat_models.deepinfra import ( ChatDeepInfra, ) @@ -61,24 +34,15 @@ from langchain_community.chat_models.fake import ( FakeListChatModel, ) - from langchain_community.chat_models.fireworks import ( - ChatFireworks, - ) from langchain_community.chat_models.friendli import ( ChatFriendli, ) - from langchain_community.chat_models.gigachat import ( - GigaChat, - ) from langchain_community.chat_models.google_palm import ( ChatGooglePalm, ) from langchain_community.chat_models.gpt_router import ( GPTRouter, ) - from langchain_community.chat_models.huggingface import ( - ChatHuggingFace, - ) from langchain_community.chat_models.human import ( HumanInputChatModel, ) @@ -97,12 +61,6 @@ from langchain_community.chat_models.konko import ( ChatKonko, ) - from langchain_community.chat_models.litellm import ( - ChatLiteLLM, - ) - from langchain_community.chat_models.litellm_router import ( - ChatLiteLLMRouter, - ) from langchain_community.chat_models.llama_edge import ( LlamaEdgeChatService, ) @@ -137,19 +95,10 @@ ChatOCIGenAI, # noqa: F401 ) from langchain_community.chat_models.octoai import ChatOctoAI - from langchain_community.chat_models.ollama import ( - ChatOllama, - ) - from langchain_community.chat_models.openai import ( - ChatOpenAI, - ) from langchain_community.chat_models.outlines import ChatOutlines from langchain_community.chat_models.pai_eas_endpoint import ( PaiEasChatEndpoint, ) - from langchain_community.chat_models.perplexity import ( - ChatPerplexity, - ) from langchain_community.chat_models.premai import ( ChatPremAI, ) @@ -159,16 +108,9 @@ from langchain_community.chat_models.reka import ( ChatReka, ) - from langchain_community.chat_models.sambanova import ( - ChatSambaNovaCloud, - ChatSambaStudio, - ) from langchain_community.chat_models.snowflake import ( ChatSnowflakeCortex, ) - from langchain_community.chat_models.solar import ( - SolarChat, - ) from langchain_community.chat_models.sparkllm import ( ChatSparkLLM, ) @@ -176,9 +118,6 @@ from langchain_community.chat_models.tongyi import ( ChatTongyi, ) - from langchain_community.chat_models.vertexai import ( - ChatVertexAI, - ) from langchain_community.chat_models.volcengine_maas import ( 
VolcEngineMaasChat, ) @@ -195,29 +134,20 @@ ChatZhipuAI, ) __all__ = [ - "AzureChatOpenAI", - "BedrockChat", - "ChatAnthropic", "ChatAnyscale", "ChatBaichuan", "ChatClovaX", - "ChatCohere", "ChatCoze", "ChatOctoAI", - "ChatDatabricks", "ChatDeepInfra", "ChatEdenAI", "ChatEverlyAI", - "ChatFireworks", "ChatFriendli", "ChatGooglePalm", - "ChatHuggingFace", "ChatHunyuan", "ChatJavelinAIGateway", "ChatKinetica", "ChatKonko", - "ChatLiteLLM", - "ChatLiteLLMRouter", "ChatMLX", "ChatMLflowAIGateway", "ChatMaritalk", @@ -227,18 +157,12 @@ "ChatOCIModelDeployment", "ChatOCIModelDeploymentVLLM", "ChatOCIModelDeploymentTGI", - "ChatOllama", - "ChatOpenAI", "ChatOutlines", - "ChatPerplexity", "ChatReka", "ChatPremAI", - "ChatSambaNovaCloud", - "ChatSambaStudio", "ChatSparkLLM", "ChatSnowflakeCortex", "ChatTongyi", - "ChatVertexAI", "ChatYandexGPT", "ChatYuan2", "ChatZhipuAI", @@ -246,7 +170,6 @@ "ErnieBotChat", "FakeListChatModel", "GPTRouter", - "GigaChat", "HumanInputChatModel", "JinaChat", "LlamaEdgeChatService", @@ -255,35 +178,25 @@ "PaiEasChatEndpoint", "PromptLayerChatOpenAI", "QianfanChatEndpoint", - "SolarChat", "VolcEngineMaasChat", "ChatYi", ] _module_lookup = { - "AzureChatOpenAI": "langchain_community.chat_models.azure_openai", - "BedrockChat": "langchain_community.chat_models.bedrock", - "ChatAnthropic": "langchain_community.chat_models.anthropic", "ChatAnyscale": "langchain_community.chat_models.anyscale", "ChatBaichuan": "langchain_community.chat_models.baichuan", "ChatClovaX": "langchain_community.chat_models.naver", - "ChatCohere": "langchain_community.chat_models.cohere", "ChatCoze": "langchain_community.chat_models.coze", - "ChatDatabricks": "langchain_community.chat_models.databricks", "ChatDeepInfra": "langchain_community.chat_models.deepinfra", "ChatEverlyAI": "langchain_community.chat_models.everlyai", "ChatEdenAI": "langchain_community.chat_models.edenai", - "ChatFireworks": "langchain_community.chat_models.fireworks", "ChatFriendli": "langchain_community.chat_models.friendli", "ChatGooglePalm": "langchain_community.chat_models.google_palm", - "ChatHuggingFace": "langchain_community.chat_models.huggingface", "ChatHunyuan": "langchain_community.chat_models.hunyuan", "ChatJavelinAIGateway": "langchain_community.chat_models.javelin_ai_gateway", "ChatKinetica": "langchain_community.chat_models.kinetica", "ChatKonko": "langchain_community.chat_models.konko", - "ChatLiteLLM": "langchain_community.chat_models.litellm", - "ChatLiteLLMRouter": "langchain_community.chat_models.litellm_router", "ChatMLflowAIGateway": "langchain_community.chat_models.mlflow_ai_gateway", "ChatMLX": "langchain_community.chat_models.mlx", "ChatMaritalk": "langchain_community.chat_models.maritalk", @@ -294,24 +207,17 @@ "ChatOCIModelDeployment": "langchain_community.chat_models.oci_data_science", "ChatOCIModelDeploymentVLLM": "langchain_community.chat_models.oci_data_science", "ChatOCIModelDeploymentTGI": "langchain_community.chat_models.oci_data_science", - "ChatOllama": "langchain_community.chat_models.ollama", - "ChatOpenAI": "langchain_community.chat_models.openai", "ChatOutlines": "langchain_community.chat_models.outlines", "ChatReka": "langchain_community.chat_models.reka", - "ChatPerplexity": "langchain_community.chat_models.perplexity", - "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova", - "ChatSambaStudio": "langchain_community.chat_models.sambanova", "ChatSnowflakeCortex": "langchain_community.chat_models.snowflake", "ChatSparkLLM": "langchain_community.chat_models.sparkllm", 
"ChatTongyi": "langchain_community.chat_models.tongyi", - "ChatVertexAI": "langchain_community.chat_models.vertexai", "ChatYandexGPT": "langchain_community.chat_models.yandex", "ChatYuan2": "langchain_community.chat_models.yuan2", "ChatZhipuAI": "langchain_community.chat_models.zhipuai", "ErnieBotChat": "langchain_community.chat_models.ernie", "FakeListChatModel": "langchain_community.chat_models.fake", "GPTRouter": "langchain_community.chat_models.gpt_router", - "GigaChat": "langchain_community.chat_models.gigachat", "HumanInputChatModel": "langchain_community.chat_models.human", "JinaChat": "langchain_community.chat_models.jinachat", "LlamaEdgeChatService": "langchain_community.chat_models.llama_edge", @@ -319,7 +225,6 @@ "MoonshotChat": "langchain_community.chat_models.moonshot", "PaiEasChatEndpoint": "langchain_community.chat_models.pai_eas_endpoint", "PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai", - "SolarChat": "langchain_community.chat_models.solar", "QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint", "VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas", "ChatPremAI": "langchain_community.chat_models.premai", diff --git a/libs/community/langchain_community/chat_models/anthropic.py b/libs/community/langchain_community/chat_models/anthropic.py deleted file mode 100644 index cd7160eb5..000000000 --- a/libs/community/langchain_community/chat_models/anthropic.py +++ /dev/null @@ -1,234 +0,0 @@ -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import ( - BaseChatModel, - agenerate_from_stream, - generate_from_stream, -) -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.prompt_values import PromptValue -from pydantic import ConfigDict - -from langchain_community.llms.anthropic import _AnthropicCommon - - -def _convert_one_message_to_text( - message: BaseMessage, - human_prompt: str, - ai_prompt: str, -) -> str: - content = cast(str, message.content) - if isinstance(message, ChatMessage): - message_text = f"\n\n{message.role.capitalize()}: {content}" - elif isinstance(message, HumanMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, AIMessage): - message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemMessage): - message_text = content - else: - raise ValueError(f"Got unknown type {message}") - return message_text - - -def convert_messages_to_prompt_anthropic( - messages: List[BaseMessage], - *, - human_prompt: str = "\n\nHuman:", - ai_prompt: str = "\n\nAssistant:", -) -> str: - """Format a list of messages into a full prompt for the Anthropic model - Args: - messages (List[BaseMessage]): List of BaseMessage to combine. - human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:". - ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:". - Returns: - str: Combined string with necessary human_prompt and ai_prompt tags. 
- """ - - messages = messages.copy() # don't mutate the original list - if not isinstance(messages[-1], AIMessage): - messages.append(AIMessage(content="")) - - text = "".join( - _convert_one_message_to_text(message, human_prompt, ai_prompt) - for message in messages - ) - - # trim off the trailing ' ' that might come from the "Assistant: " - return text.rstrip() - - -@deprecated( - since="0.0.28", - removal="1.0", - alternative_import="langchain_anthropic.ChatAnthropic", -) -class ChatAnthropic(BaseChatModel, _AnthropicCommon): - """`Anthropic` chat large language models. - - To use, you should have the ``anthropic`` python package installed, and the - environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass - it as a named parameter to the constructor. - - Example: - .. code-block:: python - - import anthropic - from langchain_community.chat_models import ChatAnthropic - model = ChatAnthropic(model="", anthropic_api_key="my-api-key") - """ - - model_config = ConfigDict( - populate_by_name=True, - arbitrary_types_allowed=True, - ) - - @property - def lc_secrets(self) -> Dict[str, str]: - return {"anthropic_api_key": "ANTHROPIC_API_KEY"} - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "anthropic-chat" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this model can be serialized by Langchain.""" - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" - return ["langchain", "chat_models", "anthropic"] - - def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str: - """Format a list of messages into a full prompt for the Anthropic model - Args: - messages (List[BaseMessage]): List of BaseMessage to combine. - Returns: - str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags. 
- """ - prompt_params = {} - if self.HUMAN_PROMPT: - prompt_params["human_prompt"] = self.HUMAN_PROMPT - if self.AI_PROMPT: - prompt_params["ai_prompt"] = self.AI_PROMPT - return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params) - - def convert_prompt(self, prompt: PromptValue) -> str: - return self._convert_messages_to_prompt(prompt.to_messages()) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - prompt = self._convert_messages_to_prompt(messages) - params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} - if stop: - params["stop_sequences"] = stop - - stream_resp = self.client.completions.create(**params, stream=True) - for data in stream_resp: - delta = data.completion - chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - run_manager.on_llm_new_token(delta, chunk=chunk) - yield chunk - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - prompt = self._convert_messages_to_prompt(messages) - params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} - if stop: - params["stop_sequences"] = stop - - stream_resp = await self.async_client.completions.create(**params, stream=True) - async for data in stream_resp: - delta = data.completion - chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - await run_manager.on_llm_new_token(delta, chunk=chunk) - yield chunk - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - prompt = self._convert_messages_to_prompt( - messages, - ) - params: Dict[str, Any] = { - "prompt": prompt, - **self._default_params, - **kwargs, - } - if stop: - params["stop_sequences"] = stop - response = self.client.completions.create(**params) - completion = response.completion - message = AIMessage(content=completion) - return ChatResult(generations=[ChatGeneration(message=message)]) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if self.streaming: - stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) - prompt = self._convert_messages_to_prompt( - messages, - ) - params: Dict[str, Any] = { - "prompt": prompt, - **self._default_params, - **kwargs, - } - if stop: - params["stop_sequences"] = stop - response = await self.async_client.completions.create(**params) - completion = response.completion - message = AIMessage(content=completion) - return ChatResult(generations=[ChatGeneration(message=message)]) - - def get_num_tokens(self, text: str) -> int: - """Calculate number of tokens.""" - if not self.count_tokens: - raise NameError("Please ensure the anthropic package is loaded") - return self.count_tokens(text) diff --git a/libs/community/langchain_community/chat_models/anyscale.py 
b/libs/community/langchain_community/chat_models/anyscale.py index 1e1a12e6d..7dd3e0d39 100644 --- a/libs/community/langchain_community/chat_models/anyscale.py +++ b/libs/community/langchain_community/chat_models/anyscale.py @@ -27,7 +27,6 @@ from langchain_community.adapters.openai import convert_message_to_dict from langchain_community.chat_models.openai import ( ChatOpenAI, - _import_tiktoken, ) from langchain_community.utils.openai import is_openai_v1 @@ -40,6 +39,18 @@ DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + class ChatAnyscale(ChatOpenAI): """`Anyscale` Chat large language models. diff --git a/libs/community/langchain_community/chat_models/azure_openai.py b/libs/community/langchain_community/chat_models/azure_openai.py deleted file mode 100644 index 4f52d4e56..000000000 --- a/libs/community/langchain_community/chat_models/azure_openai.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Azure OpenAI chat wrapper.""" - -from __future__ import annotations - -import logging -import os -import warnings -from typing import Any, Awaitable, Callable, Dict, List, Union - -from langchain_core._api.deprecation import deprecated -from langchain_core.outputs import ChatResult -from langchain_core.utils import get_from_dict_or_env, pre_init -from pydantic import BaseModel, Field - -from langchain_community.chat_models.openai import ChatOpenAI -from langchain_community.utils.openai import is_openai_v1 - -logger = logging.getLogger(__name__) - - -@deprecated( - since="0.0.10", - removal="1.0", - alternative_import="langchain_openai.AzureChatOpenAI", -) -class AzureChatOpenAI(ChatOpenAI): - """`Azure OpenAI` Chat Completion API. - - To use this class you - must have a deployed model on Azure OpenAI. Use `deployment_name` in the - constructor to refer to the "Model deployment name" in the Azure portal. - - In addition, you should have the ``openai`` python package installed, and the - following environment variables set or passed in constructor in lower case: - - ``AZURE_OPENAI_API_KEY`` - - ``AZURE_OPENAI_ENDPOINT`` - - ``AZURE_OPENAI_AD_TOKEN`` - - ``OPENAI_API_VERSION`` - - ``OPENAI_PROXY`` - - For example, if you have `gpt-35-turbo` deployed, with the deployment name - `35-turbo-dev`, the constructor should look like: - - .. code-block:: python - - AzureChatOpenAI( - azure_deployment="35-turbo-dev", - openai_api_version="2023-05-15", - ) - - Be aware the API version may change. - - You can also specify the version of the model using ``model_version`` constructor - parameter, as Azure OpenAI doesn't return model version with the response. - - Default is empty. When you specify the version, it will be appended to the - model name in the response. Setting correct version will help you to calculate the - cost properly. Model version is not validated, so make sure you set it correctly - to get the correct cost. - - Any parameters that are valid to be passed to the openai.create call can be passed - in, even if not explicitly saved on this class. - """ - - azure_endpoint: Union[str, None] = None - """Your Azure endpoint, including the resource. - - Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided. 
- - Example: `https://example-resource.azure.openai.com/` - """ - deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment") - """A model deployment. - - If given sets the base client URL to include `/deployments/{azure_deployment}`. - Note: this means you won't be able to use non-deployment endpoints. - """ - openai_api_version: str = Field(default="", alias="api_version") - """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" - openai_api_key: Union[str, None] = Field(default=None, alias="api_key") - """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" - azure_ad_token: Union[str, None] = None - """Your Azure Active Directory token. - - Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. - - For more: - https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. - """ - azure_ad_token_provider: Union[Callable[[], str], None] = None - """A function that returns an Azure Active Directory token. - - Will be invoked on every sync request. For async requests, - will be invoked if `azure_ad_async_token_provider` is not provided. - """ - azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None - """A function that returns an Azure Active Directory token. - - Will be invoked on every async request. - """ - model_version: str = "" - """Legacy, for openai<1.0.0 support.""" - openai_api_type: str = "" - """Legacy, for openai<1.0.0 support.""" - validate_base_url: bool = True - """For backwards compatibility. If legacy val openai_api_base is passed in, try to - infer if it is a base_url or azure_endpoint and update accordingly. - """ - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" - return ["langchain", "chat_models", "azure_openai"] - - @pre_init - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - if values["n"] < 1: - raise ValueError("n must be at least 1.") - if values["n"] > 1 and values["streaming"]: - raise ValueError("n must be 1 when streaming.") - - # Check OPENAI_KEY for backwards compatibility. - # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using - # other forms of azure credentials. - values["openai_api_key"] = ( - values["openai_api_key"] - or os.getenv("AZURE_OPENAI_API_KEY") - or os.getenv("OPENAI_API_KEY") - ) - values["openai_api_base"] = values["openai_api_base"] or os.getenv( - "OPENAI_API_BASE" - ) - values["openai_api_version"] = values["openai_api_version"] or os.getenv( - "OPENAI_API_VERSION" - ) - # Check OPENAI_ORGANIZATION for backwards compatibility. - values["openai_organization"] = ( - values["openai_organization"] - or os.getenv("OPENAI_ORG_ID") - or os.getenv("OPENAI_ORGANIZATION") - ) - values["azure_endpoint"] = values["azure_endpoint"] or os.getenv( - "AZURE_OPENAI_ENDPOINT" - ) - values["azure_ad_token"] = values["azure_ad_token"] or os.getenv( - "AZURE_OPENAI_AD_TOKEN" - ) - - values["openai_api_type"] = get_from_dict_or_env( - values, "openai_api_type", "OPENAI_API_TYPE", default="azure" - ) - values["openai_proxy"] = get_from_dict_or_env( - values, "openai_proxy", "OPENAI_PROXY", default="" - ) - - try: - import openai - - except ImportError: - raise ImportError( - "Could not import openai python package. " - "Please install it with `pip install openai`." - ) - if is_openai_v1(): - # For backwards compatibility. 
Before openai v1, no distinction was made - # between azure_endpoint and base_url (openai_api_base). - openai_api_base = values["openai_api_base"] - if openai_api_base and values["validate_base_url"]: - if "/openai" not in openai_api_base: - values["openai_api_base"] = ( - values["openai_api_base"].rstrip("/") + "/openai" - ) - warnings.warn( - "As of openai>=1.0.0, Azure endpoints should be specified via " - f"the `azure_endpoint` param not `openai_api_base` " - f"(or alias `base_url`). Updating `openai_api_base` from " - f"{openai_api_base} to {values['openai_api_base']}." - ) - if values["deployment_name"]: - warnings.warn( - "As of openai>=1.0.0, if `deployment_name` (or alias " - "`azure_deployment`) is specified then " - "`openai_api_base` (or alias `base_url`) should not be. " - "Instead use `deployment_name` (or alias `azure_deployment`) " - "and `azure_endpoint`." - ) - if values["deployment_name"] not in values["openai_api_base"]: - warnings.warn( - "As of openai>=1.0.0, if `openai_api_base` " - "(or alias `base_url`) is specified it is expected to be " - "of the form " - "https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501 - f"Updating {openai_api_base} to " - f"{values['openai_api_base']}." - ) - values["openai_api_base"] += ( - "/deployments/" + values["deployment_name"] - ) - values["deployment_name"] = None - client_params = { - "api_version": values["openai_api_version"], - "azure_endpoint": values["azure_endpoint"], - "azure_deployment": values["deployment_name"], - "api_key": values["openai_api_key"], - "azure_ad_token": values["azure_ad_token"], - "azure_ad_token_provider": values["azure_ad_token_provider"], - "organization": values["openai_organization"], - "base_url": values["openai_api_base"], - "timeout": values["request_timeout"], - "max_retries": values["max_retries"], - "default_headers": { - **(values["default_headers"] or {}), - "User-Agent": "langchain-comm-python-azure-openai", - }, - "default_query": values["default_query"], - "http_client": values["http_client"], - } - values["client"] = openai.AzureOpenAI(**client_params).chat.completions - - azure_ad_async_token_provider = values["azure_ad_async_token_provider"] - - if azure_ad_async_token_provider: - client_params["azure_ad_token_provider"] = azure_ad_async_token_provider - - values["async_client"] = openai.AsyncAzureOpenAI( - **client_params - ).chat.completions - else: - values["client"] = openai.ChatCompletion - return values - - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling OpenAI API.""" - if is_openai_v1(): - return super()._default_params - else: - return { - **super()._default_params, - "engine": self.deployment_name, - } - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" - return {**self._default_params} - - @property - def _client_params(self) -> Dict[str, Any]: - """Get the config params used for the openai client.""" - if is_openai_v1(): - return super()._client_params - else: - return { - **super()._client_params, - "api_type": self.openai_api_type, - "api_version": self.openai_api_version, - } - - @property - def _llm_type(self) -> str: - return "azure-openai-chat" - - @property - def lc_attributes(self) -> Dict[str, Any]: - return { - "openai_api_type": self.openai_api_type, - "openai_api_version": self.openai_api_version, - } - - def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: - if not isinstance(response, 
dict):
-            response = response.dict()
-        for res in response["choices"]:
-            if res.get("finish_reason", None) == "content_filter":
-                raise ValueError(
-                    "Azure has not provided the response due to a content filter "
-                    "being triggered"
-                )
-        chat_result = super()._create_chat_result(response)
-
-        if "model" in response:
-            model = response["model"]
-            if self.model_version:
-                model = f"{model}-{self.model_version}"
-
-        if chat_result.llm_output is not None and isinstance(
-            chat_result.llm_output, dict
-        ):
-            chat_result.llm_output["model_name"] = model
-
-        return chat_result
diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
deleted file mode 100644
index 086a4d461..000000000
--- a/libs/community/langchain_community/chat_models/bedrock.py
+++ /dev/null
@@ -1,337 +0,0 @@
-import re
-from collections import defaultdict
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
-
-from langchain_core._api.deprecation import deprecated
-from langchain_core.callbacks import (
-    CallbackManagerForLLMRun,
-)
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import (
-    AIMessage,
-    AIMessageChunk,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from pydantic import ConfigDict
-
-from langchain_community.chat_models.anthropic import (
-    convert_messages_to_prompt_anthropic,
-)
-from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
-from langchain_community.llms.bedrock import BedrockBase
-from langchain_community.utilities.anthropic import (
-    get_num_tokens_anthropic,
-    get_token_ids_anthropic,
-)
-
-
-def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
-    if isinstance(message, ChatMessage):
-        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
-    elif isinstance(message, HumanMessage):
-        message_text = f"[INST] {message.content} [/INST]"
-    elif isinstance(message, AIMessage):
-        message_text = f"{message.content}"
-    elif isinstance(message, SystemMessage):
-        message_text = f"<<SYS>> {message.content} <</SYS>>"
-    else:
-        raise ValueError(f"Got unknown type {message}")
-    return message_text
-
-
-def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
-    """Convert a list of messages to a prompt for mistral."""
-    return "\n".join(
-        [_convert_one_message_to_text_mistral(message) for message in messages]
-    )
-
-
-def _format_image(image_url: str) -> Dict:
-    """
-    Formats an image of format data:image/jpeg;base64,{b64_string}
-    to a dict for anthropic api
-
-    {
-        "type": "base64",
-        "media_type": "image/jpeg",
-        "data": "/9j/4AAQSkZJRg...",
-    }
-
-    And throws an error if it's not a b64 image
-    """
-    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
-    match = re.match(regex, image_url)
-    if match is None:
-        raise ValueError(
-            "Anthropic only supports base64-encoded images currently."
-            " Example: data:image/png;base64,'/9j/4AAQSk'..."
- ) - return { - "type": "base64", - "media_type": match.group("media_type"), - "data": match.group("data"), - } - - -def _format_anthropic_messages( - messages: List[BaseMessage], -) -> Tuple[Optional[str], List[Dict]]: - """Format messages for anthropic.""" - - """ - [ - { - "role": _message_type_lookups[m.type], - "content": [_AnthropicMessageContent(text=m.content).dict()], - } - for m in messages - ] - """ - system: Optional[str] = None - formatted_messages: List[Dict] = [] - for i, message in enumerate(messages): - if message.type == "system": - if i != 0: - raise ValueError("System message must be at beginning of message list.") - if not isinstance(message.content, str): - raise ValueError( - "System message must be a string, " - f"instead was: {type(message.content)}" - ) - system = message.content - continue - - role = _message_type_lookups[message.type] - content: Union[str, List[Dict]] - - if not isinstance(message.content, str): - # parse as dict - assert isinstance(message.content, list), ( - "Anthropic message content must be str or list of dicts" - ) - - # populate content - content = [] - for item in message.content: - if isinstance(item, str): - content.append( - { - "type": "text", - "text": item, - } - ) - elif isinstance(item, dict): - if "type" not in item: - raise ValueError("Dict content item must have a type key") - if item["type"] == "image_url": - # convert format - source = _format_image(item["image_url"]["url"]) - content.append( - { - "type": "image", - "source": source, - } - ) - else: - content.append(item) - else: - raise ValueError( - f"Content items must be str or dict, instead was: {type(item)}" - ) - else: - content = message.content - - formatted_messages.append( - { - "role": role, - "content": content, - } - ) - return system, formatted_messages - - -class ChatPromptAdapter: - """Adapter class to prepare the inputs from Langchain to prompt format - that Chat model expects. - """ - - @classmethod - def convert_messages_to_prompt( - cls, provider: str, messages: List[BaseMessage] - ) -> str: - if provider == "anthropic": - prompt = convert_messages_to_prompt_anthropic(messages=messages) - elif provider == "meta": - prompt = convert_messages_to_prompt_llama(messages=messages) - elif provider == "mistral": - prompt = convert_messages_to_prompt_mistral(messages=messages) - elif provider == "amazon": - prompt = convert_messages_to_prompt_anthropic( - messages=messages, - human_prompt="\n\nUser:", - ai_prompt="\n\nBot:", - ) - else: - raise NotImplementedError( - f"Provider {provider} model does not support chat." 
- ) - return prompt - - @classmethod - def format_messages( - cls, provider: str, messages: List[BaseMessage] - ) -> Tuple[Optional[str], List[Dict]]: - if provider == "anthropic": - return _format_anthropic_messages(messages) - - raise NotImplementedError( - f"Provider {provider} not supported for format_messages" - ) - - -_message_type_lookups = { - "human": "user", - "ai": "assistant", - "AIMessageChunk": "assistant", - "HumanMessageChunk": "user", - "function": "user", -} - - -@deprecated( - since="0.0.34", removal="1.0", alternative_import="langchain_aws.ChatBedrock" -) -class BedrockChat(BaseChatModel, BedrockBase): - """Chat model that uses the Bedrock API.""" - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "amazon_bedrock_chat" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this model can be serialized by Langchain.""" - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" - return ["langchain", "chat_models", "bedrock"] - - @property - def lc_attributes(self) -> Dict[str, Any]: - attributes: Dict[str, Any] = {} - - if self.region_name: - attributes["region_name"] = self.region_name - - return attributes - - model_config = ConfigDict( - extra="forbid", - ) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - provider = self._get_provider() - prompt, system, formatted_messages = None, None, None - - if provider == "anthropic": - system, formatted_messages = ChatPromptAdapter.format_messages( - provider, messages - ) - else: - prompt = ChatPromptAdapter.convert_messages_to_prompt( - provider=provider, messages=messages - ) - - for chunk in self._prepare_input_and_invoke_stream( - prompt=prompt, - system=system, - messages=formatted_messages, - stop=stop, - run_manager=run_manager, - **kwargs, - ): - delta = chunk.text - yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - completion = "" - llm_output: Dict[str, Any] = {"model_id": self.model_id} - - if self.streaming: - for chunk in self._stream(messages, stop, run_manager, **kwargs): - completion += chunk.text - else: - provider = self._get_provider() - prompt, system, formatted_messages = None, None, None - params: Dict[str, Any] = {**kwargs} - - if provider == "anthropic": - system, formatted_messages = ChatPromptAdapter.format_messages( - provider, messages - ) - else: - prompt = ChatPromptAdapter.convert_messages_to_prompt( - provider=provider, messages=messages - ) - - if stop: - params["stop_sequences"] = stop - - completion, usage_info = self._prepare_input_and_invoke( - prompt=prompt, - stop=stop, - run_manager=run_manager, - system=system, - messages=formatted_messages, - **params, - ) - - llm_output["usage"] = usage_info - - return ChatResult( - generations=[ChatGeneration(message=AIMessage(content=completion))], - llm_output=llm_output, - ) - - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: - final_usage: Dict[str, int] = defaultdict(int) - final_output = {} - for output in llm_outputs: - output = output or {} - usage = output.get("usage", {}) - for token_type, token_count in usage.items(): - 
final_usage[token_type] += token_count - final_output.update(output) - final_output["usage"] = final_usage - return final_output - - def get_num_tokens(self, text: str) -> int: - if self._model_is_anthropic: - return get_num_tokens_anthropic(text) - else: - return super().get_num_tokens(text) - - def get_token_ids(self, text: str) -> List[int]: - if self._model_is_anthropic: - return get_token_ids_anthropic(text) - else: - return super().get_token_ids(text) diff --git a/libs/community/langchain_community/chat_models/cloudflare_workersai.py b/libs/community/langchain_community/chat_models/cloudflare_workersai.py deleted file mode 100644 index ece4845cc..000000000 --- a/libs/community/langchain_community/chat_models/cloudflare_workersai.py +++ /dev/null @@ -1,256 +0,0 @@ -import logging -from operator import itemgetter -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - Sequence, - Type, - Union, - cast, -) -from uuid import uuid4 - -import requests -from langchain_classic.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import ( - AIMessageChunk, - BaseMessage, - SystemMessage, - ToolCall, - ToolMessage, -) -from langchain_core.messages.tool import tool_call -from langchain_core.output_parsers import ( - JsonOutputParser, - PydanticOutputParser, -) -from langchain_core.output_parsers.base import OutputParserLike -from langchain_core.output_parsers.openai_tools import ( - JsonOutputKeyToolsParser, - PydanticToolsParser, -) -from langchain_core.runnables import Runnable, RunnablePassthrough -from langchain_core.runnables.base import RunnableMap -from langchain_core.tools import BaseTool -from langchain_core.utils.function_calling import convert_to_openai_tool -from langchain_core.utils.pydantic import is_basemodel_subclass -from pydantic import BaseModel, Field - -# Initialize logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) -_logger = logging.getLogger(__name__) - - -def _is_pydantic_class(obj: Any) -> bool: - return isinstance(obj, type) and is_basemodel_subclass(obj) - - -def _convert_messages_to_cloudflare_messages( - messages: List[BaseMessage], -) -> List[Dict[str, Any]]: - """Convert LangChain messages to Cloudflare Workers AI format.""" - cloudflare_messages = [] - msg: Dict[str, Any] - for message in messages: - # Base structure for each message - msg = { - "role": "", - "content": message.content if isinstance(message.content, str) else "", - } - - # Determine role and additional fields based on message type - if isinstance(message, HumanMessage): - msg["role"] = "user" - elif isinstance(message, AIMessage): - msg["role"] = "assistant" - # If the AIMessage includes tool calls, format them as needed - if message.tool_calls: - tool_calls = [ - {"name": tool_call["name"], "arguments": tool_call["args"]} - for tool_call in message.tool_calls - ] - msg["tool_calls"] = tool_calls - elif isinstance(message, SystemMessage): - msg["role"] = "system" - elif isinstance(message, ToolMessage): - msg["role"] = "tool" - msg["tool_call_id"] = ( - message.tool_call_id - ) # Use tool_call_id if it's a ToolMessage - - # Add the formatted message to the list - 
cloudflare_messages.append(msg) - - return cloudflare_messages - - -def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]: - """Get tool calls from ollama response.""" - tool_calls = [] - if "tool_calls" in response.json()["result"]: - for tc in response.json()["result"]["tool_calls"]: - tool_calls.append( - tool_call( - id=str(uuid4()), - name=tc["name"], - args=tc["arguments"], - ) - ) - return tool_calls - - -@deprecated( - since="0.3.23", - removal="1.0", - alternative_import="langchain_cloudflare.ChatCloudflareWorkersAI", -) -class ChatCloudflareWorkersAI(BaseChatModel): - """Custom chat model for Cloudflare Workers AI""" - - account_id: str = Field(...) - api_token: str = Field(...) - model: str = Field(...) - ai_gateway: str = "" - url: str = "" - base_url: str = "https://api.cloudflare.com/client/v4/accounts" - gateway_url: str = "https://gateway.ai.cloudflare.com/v1" - - def __init__(self, **kwargs: Any) -> None: - """Initialize with necessary credentials.""" - super().__init__(**kwargs) - if self.ai_gateway: - self.url = ( - f"{self.gateway_url}/{self.account_id}/" - f"{self.ai_gateway}/workers-ai/run/{self.model}" - ) - else: - self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}" - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Generate a response based on the messages provided.""" - formatted_messages = _convert_messages_to_cloudflare_messages(messages) - - headers = {"Authorization": f"Bearer {self.api_token}"} - prompt = "\n".join( - f"role: {msg['role']}, content: {msg['content']}" - + (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "") - + ( - f", tool_call_id: {msg['tool_call_id']}" - if "tool_call_id" in msg - else "" - ) - for msg in formatted_messages - ) - - # Initialize `data` with `prompt` - data = { - "prompt": prompt, - "tools": kwargs["tools"] if "tools" in kwargs else None, - **{key: value for key, value in kwargs.items() if key not in ["tools"]}, - } - - # Ensure `tools` is a list if it's included in `kwargs` - if data["tools"] is not None and not isinstance(data["tools"], list): - data["tools"] = [data["tools"]] - - _logger.info(f"Sending prompt to Cloudflare Workers AI: {data}") - - response = requests.post(self.url, headers=headers, json=data) - tool_calls = _get_tool_calls_from_response(response) - ai_message = AIMessage( - content=str(response.json()), tool_calls=cast(AIMessageChunk, tool_calls) - ) - chat_generation = ChatGeneration(message=ai_message) - return ChatResult(generations=[chat_generation]) - - def bind_tools( - self, - tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]], - **kwargs: Any, - ) -> Runnable[LanguageModelInput, AIMessage]: - """Bind tools for use in model generation.""" - formatted_tools = [convert_to_openai_tool(tool) for tool in tools] - return super().bind(tools=formatted_tools, **kwargs) - - def with_structured_output( - self, - schema: Union[Dict, Type[BaseModel]], - *, - include_raw: bool = False, - method: Optional[Literal["json_mode", "function_calling"]] = "function_calling", - **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: - """Model wrapper that returns outputs formatted to match the given schema.""" - - _ = kwargs.pop("strict", None) - if kwargs: - raise ValueError(f"Received unsupported arguments {kwargs}") - is_pydantic_schema = _is_pydantic_class(schema) - if method == 
"json_schema": - # Some applications require that incompatible parameters (e.g., unsupported - # methods) be handled. - method = "function_calling" - if method == "function_calling": - if schema is None: - raise ValueError( - "schema must be specified when method is 'function_calling'. " - "Received None." - ) - tool_name = convert_to_openai_tool(schema)["function"]["name"] - llm = self.bind_tools([schema], tool_choice=tool_name) - if is_pydantic_schema: - output_parser: OutputParserLike = PydanticToolsParser( - tools=[schema], # type: ignore[list-item] - first_tool_only=True, - ) - else: - output_parser = JsonOutputKeyToolsParser( - key_name=tool_name, first_tool_only=True - ) - elif method == "json_mode": - llm = self.bind(response_format={"type": "json_object"}) - output_parser = ( - PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] - if is_pydantic_schema - else JsonOutputParser() - ) - else: - raise ValueError( - f"Unrecognized method argument. Expected one of 'function_calling' or " - f"'json_mode'. Received: '{method}'" - ) - - if include_raw: - parser_assign = RunnablePassthrough.assign( - parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None - ) - parser_none = RunnablePassthrough.assign(parsed=lambda _: None) - parser_with_fallback = parser_assign.with_fallbacks( - [parser_none], exception_key="parsing_error" - ) - return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser - - @property - def _llm_type(self) -> str: - """Return the type of the LLM (for Langchain compatibility).""" - return "cloudflare-workers-ai" diff --git a/libs/community/langchain_community/chat_models/cohere.py b/libs/community/langchain_community/chat_models/cohere.py deleted file mode 100644 index d2e8560a1..000000000 --- a/libs/community/langchain_community/chat_models/cohere.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import ( - BaseChatModel, - agenerate_from_stream, - generate_from_stream, -) -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from pydantic import ConfigDict - -from langchain_community.llms.cohere import BaseCohere - - -def get_role(message: BaseMessage) -> str: - """Get the role of the message. - - Args: - message: The message. - - Returns: - The role of the message. - - Raises: - ValueError: If the message is of an unknown type. - """ - if isinstance(message, ChatMessage) or isinstance(message, HumanMessage): - return "User" - elif isinstance(message, AIMessage): - return "Chatbot" - elif isinstance(message, SystemMessage): - return "System" - else: - raise ValueError(f"Got unknown type {message}") - - -def get_cohere_chat_request( - messages: List[BaseMessage], - *, - connectors: Optional[List[Dict[str, str]]] = None, - **kwargs: Any, -) -> Dict[str, Any]: - """Get the request for the Cohere chat API. - - Args: - messages: The messages. - connectors: The connectors. - **kwargs: The keyword arguments. - - Returns: - The request for the Cohere chat API. 
- """ - documents = ( - None - if "source_documents" not in kwargs - else [ - { - "snippet": doc.page_content, - "id": doc.metadata.get("id") or f"doc-{str(i)}", - } - for i, doc in enumerate(kwargs["source_documents"]) - ] - ) - kwargs.pop("source_documents", None) - maybe_connectors = connectors if documents is None else None - - # by enabling automatic prompt truncation, the probability of request failure is - # reduced with minimal impact on response quality - prompt_truncation = ( - "AUTO" if documents is not None or connectors is not None else None - ) - - req = { - "message": messages[-1].content, - "chat_history": [ - {"role": get_role(x), "message": x.content} for x in messages[:-1] - ], - "documents": documents, - "connectors": maybe_connectors, - "prompt_truncation": prompt_truncation, - **kwargs, - } - - return {k: v for k, v in req.items() if v is not None} - - -@deprecated( - since="0.0.30", removal="1.0", alternative_import="langchain_cohere.ChatCohere" -) -class ChatCohere(BaseChatModel, BaseCohere): - """`Cohere` chat large language models. - - To use, you should have the ``cohere`` python package installed, and the - environment variable ``COHERE_API_KEY`` set with your API key, or pass - it as a named parameter to the constructor. - - Example: - .. code-block:: python - - from langchain_community.chat_models import ChatCohere - from langchain_core.messages import HumanMessage - - chat = ChatCohere(max_tokens=256, temperature=0.75) - - messages = [HumanMessage(content="knock knock")] - chat.invoke(messages) - """ - - model_config = ConfigDict( - populate_by_name=True, - arbitrary_types_allowed=True, - ) - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "cohere-chat" - - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling Cohere API.""" - return { - "temperature": self.temperature, - } - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" - return {**{"model": self.model}, **self._default_params} - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - - if hasattr(self.client, "chat_stream"): # detect and support sdk v5 - stream = self.client.chat_stream(**request) - else: - stream = self.client.chat(**request, stream=True) - - for data in stream: - if data.event_type == "text-generation": - delta = data.text - chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - run_manager.on_llm_new_token(delta, chunk=chunk) - yield chunk - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - - if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5 - stream = await self.async_client.chat_stream(**request) - else: - stream = await self.async_client.chat(**request, stream=True) - - async for data in stream: - if data.event_type == "text-generation": - delta = data.text - chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - await run_manager.on_llm_new_token(delta, chunk=chunk) - yield chunk - - 
def _get_generation_info(self, response: Any) -> Dict[str, Any]: - """Get the generation info from cohere API response.""" - return { - "documents": response.documents, - "citations": response.citations, - "search_results": response.search_results, - "search_queries": response.search_queries, - "token_count": response.token_count, - } - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - response = self.client.chat(**request) - - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) - return ChatResult( - generations=[ - ChatGeneration(message=message, generation_info=generation_info) - ] - ) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if self.streaming: - stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) - - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - response = self.client.chat(**request) - - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) - return ChatResult( - generations=[ - ChatGeneration(message=message, generation_info=generation_info) - ] - ) - - def get_num_tokens(self, text: str) -> int: - """Calculate number of tokens.""" - return len(self.client.tokenize(text=text).tokens) diff --git a/libs/community/langchain_community/chat_models/databricks.py b/libs/community/langchain_community/chat_models/databricks.py deleted file mode 100644 index dd459e4f2..000000000 --- a/libs/community/langchain_community/chat_models/databricks.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -from urllib.parse import urlparse - -from langchain_core._api import deprecated - -from langchain_community.chat_models.mlflow import ChatMlflow - -logger = logging.getLogger(__name__) - - -@deprecated( - since="0.3.3", - removal="1.0", - alternative_import="databricks_langchain.ChatDatabricks", -) -class ChatDatabricks(ChatMlflow): - """`Databricks` chat models API. - - To use, you should have the ``mlflow`` python package installed. - For more information, see https://mlflow.org/docs/latest/llms/deployments. - - Example: - .. code-block:: python - - from langchain_community.chat_models import ChatDatabricks - - chat_model = ChatDatabricks( - target_uri="databricks", - endpoint="databricks-llama-2-70b-chat", - temperature=0.1, - ) - - # single input invocation - print(chat_model.invoke("What is MLflow?").content) - - # single input invocation with streaming response - for chunk in chat_model.stream("What is MLflow?"): - print(chunk.content, end="|") - """ - - target_uri: str = "databricks" - """The target URI to use. 
Defaults to ``databricks``.""" - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "databricks-chat" - - @property - def _mlflow_extras(self) -> str: - return "" - - def _validate_uri(self) -> None: - if self.target_uri == "databricks": - return - - if urlparse(self.target_uri).scheme != "databricks": - raise ValueError( - "Invalid target URI. The target URI must be a valid databricks URI." - ) diff --git a/libs/community/langchain_community/chat_models/everlyai.py b/libs/community/langchain_community/chat_models/everlyai.py index b45b40a80..1ccdaaebb 100644 --- a/libs/community/langchain_community/chat_models/everlyai.py +++ b/libs/community/langchain_community/chat_models/everlyai.py @@ -25,7 +25,6 @@ from langchain_community.adapters.openai import convert_message_to_dict from langchain_community.chat_models.openai import ( ChatOpenAI, - _import_tiktoken, ) if TYPE_CHECKING: @@ -38,6 +37,18 @@ DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf" +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + class ChatEverlyAI(ChatOpenAI): """`EverlyAI` Chat large language models. diff --git a/libs/community/langchain_community/chat_models/fireworks.py b/libs/community/langchain_community/chat_models/fireworks.py deleted file mode 100644 index b0355f6e4..000000000 --- a/libs/community/langchain_community/chat_models/fireworks.py +++ /dev/null @@ -1,372 +0,0 @@ -from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Iterator, - List, - Optional, - Type, - Union, -) - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessage, - ChatMessageChunk, - FunctionMessage, - FunctionMessageChunk, - HumanMessage, - HumanMessageChunk, - SystemMessage, - SystemMessageChunk, -) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.utils import convert_to_secret_str -from langchain_core.utils.env import get_from_dict_or_env -from pydantic import Field, SecretStr, model_validator - -from langchain_community.adapters.openai import convert_message_to_dict - - -def _convert_delta_to_message_chunk( - _dict: Any, default_class: Type[BaseMessageChunk] -) -> BaseMessageChunk: - """Convert a delta response to a message chunk.""" - role = _dict.role - content = _dict.content or "" - additional_kwargs: Dict = {} - - if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) - elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict.name) - elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) - else: - return default_class(content=content) # type: 
ignore[call-arg] - - -def convert_dict_to_message(_dict: Any) -> BaseMessage: - """Convert a dict response to a message.""" - role = _dict.role - content = _dict.content or "" - if role == "user": - return HumanMessage(content=content) - elif role == "assistant": - content = _dict.content - additional_kwargs: Dict = {} - return AIMessage(content=content, additional_kwargs=additional_kwargs) - elif role == "system": - return SystemMessage(content=content) - elif role == "function": - return FunctionMessage(content=content, name=_dict.name) - else: - return ChatMessage(content=content, role=role) - - -@deprecated( - since="0.0.26", - removal="1.0", - alternative_import="langchain_fireworks.ChatFireworks", -) -class ChatFireworks(BaseChatModel): - """Fireworks Chat models.""" - - model: str = "accounts/fireworks/models/llama-v2-7b-chat" - model_kwargs: dict = Field( - default_factory=lambda: { - "temperature": 0.7, - "max_tokens": 512, - "top_p": 1, - }.copy() - ) - fireworks_api_key: Optional[SecretStr] = None - max_retries: int = 20 - use_retry: bool = True - - @property - def lc_secrets(self) -> Dict[str, str]: - return {"fireworks_api_key": "FIREWORKS_API_KEY"} - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" - return ["langchain", "chat_models", "fireworks"] - - @model_validator(mode="before") - @classmethod - def validate_environment(cls, values: Dict) -> Any: - """Validate that api key in environment.""" - try: - import fireworks.client - except ImportError as e: - raise ImportError( - "Could not import fireworks-ai python package. " - "Please install it with `pip install fireworks-ai`." - ) from e - fireworks_api_key = convert_to_secret_str( - get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY") - ) - fireworks.client.api_key = fireworks_api_key.get_secret_value() - return values - - @property - def _llm_type(self) -> str: - """Return type of llm.""" - return "fireworks-chat" - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - message_dicts = self._create_message_dicts(messages) - - params = { - "model": self.model, - "messages": message_dicts, - **self.model_kwargs, - **kwargs, - } - response = completion_with_retry( - self, - self.use_retry, - run_manager=run_manager, - stop=stop, - **params, - ) - return self._create_chat_result(response) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - message_dicts = self._create_message_dicts(messages) - params = { - "model": self.model, - "messages": message_dicts, - **self.model_kwargs, - **kwargs, - } - response = await acompletion_with_retry( - self, self.use_retry, run_manager=run_manager, stop=stop, **params - ) - return self._create_chat_result(response) - - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: - if llm_outputs[0] is None: - return {} - return llm_outputs[0] - - def _create_chat_result(self, response: Any) -> ChatResult: - generations = [] - for res in response.choices: - message = convert_dict_to_message(res.message) - gen = ChatGeneration( - message=message, - generation_info=dict(finish_reason=res.finish_reason), - ) - generations.append(gen) - llm_output = {"model": 
self.model} - return ChatResult(generations=generations, llm_output=llm_output) - - def _create_message_dicts( - self, messages: List[BaseMessage] - ) -> List[Dict[str, Any]]: - message_dicts = [convert_message_to_dict(m) for m in messages] - return message_dicts - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - message_dicts = self._create_message_dicts(messages) - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - params = { - "model": self.model, - "messages": message_dicts, - "stream": True, - **self.model_kwargs, - **kwargs, - } - for chunk in completion_with_retry( - self, self.use_retry, run_manager=run_manager, stop=stop, **params - ): - choice = chunk.choices[0] - chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class) - finish_reason = choice.finish_reason - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None - ) - default_chunk_class = chunk.__class__ - cg_chunk = ChatGenerationChunk( - message=chunk, generation_info=generation_info - ) - if run_manager: - run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk) - yield cg_chunk - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - message_dicts = self._create_message_dicts(messages) - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - params = { - "model": self.model, - "messages": message_dicts, - "stream": True, - **self.model_kwargs, - **kwargs, - } - async for chunk in await acompletion_with_retry_streaming( - self, self.use_retry, run_manager=run_manager, stop=stop, **params - ): - choice = chunk.choices[0] - chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class) - finish_reason = choice.finish_reason - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None - ) - default_chunk_class = chunk.__class__ - cg_chunk = ChatGenerationChunk( - message=chunk, generation_info=generation_info - ) - if run_manager: - await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk) - yield cg_chunk - - -def conditional_decorator( - condition: bool, decorator: Callable[[Any], Any] -) -> Callable[[Any], Any]: - """Define conditional decorator. - - Args: - condition: The condition. - decorator: The decorator. - - Returns: - The decorated function. 
- """ - - def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]: - if condition: - return decorator(func) - return func - - return actual_decorator - - -def completion_with_retry( - llm: ChatFireworks, - use_retry: bool, - *, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, -) -> Any: - """Use tenacity to retry the completion call.""" - import fireworks.client - - retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - - @conditional_decorator(use_retry, retry_decorator) - def _completion_with_retry(**kwargs: Any) -> Any: - """Use tenacity to retry the completion call.""" - return fireworks.client.ChatCompletion.create( - **kwargs, - ) - - return _completion_with_retry(**kwargs) - - -async def acompletion_with_retry( - llm: ChatFireworks, - use_retry: bool, - *, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, -) -> Any: - """Use tenacity to retry the async completion call.""" - import fireworks.client - - retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - - @conditional_decorator(use_retry, retry_decorator) - async def _completion_with_retry(**kwargs: Any) -> Any: - return await fireworks.client.ChatCompletion.acreate( - **kwargs, - ) - - return await _completion_with_retry(**kwargs) - - -async def acompletion_with_retry_streaming( - llm: ChatFireworks, - use_retry: bool, - *, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, -) -> Any: - """Use tenacity to retry the completion call for streaming.""" - import fireworks.client - - retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - - @conditional_decorator(use_retry, retry_decorator) - async def _completion_with_retry(**kwargs: Any) -> Any: - return fireworks.client.ChatCompletion.acreate( - **kwargs, - ) - - return await _completion_with_retry(**kwargs) - - -def _create_retry_decorator( - llm: ChatFireworks, - run_manager: Optional[ - Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] - ] = None, -) -> Callable[[Any], Any]: - """Define retry mechanism.""" - import fireworks.client - - errors = [ - fireworks.client.error.RateLimitError, - fireworks.client.error.InternalServerError, - fireworks.client.error.BadGatewayError, - fireworks.client.error.ServiceUnavailableError, - ] - return create_base_retry_decorator( - error_types=errors, max_retries=llm.max_retries, run_manager=run_manager - ) diff --git a/libs/community/langchain_community/chat_models/gigachat.py b/libs/community/langchain_community/chat_models/gigachat.py deleted file mode 100644 index beeed4940..000000000 --- a/libs/community/langchain_community/chat_models/gigachat.py +++ /dev/null @@ -1,280 +0,0 @@ -from __future__ import annotations - -import logging -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterator, - Dict, - Iterator, - List, - Mapping, - Optional, - Type, -) - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import ( - BaseChatModel, - agenerate_from_stream, - generate_from_stream, -) -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessage, - ChatMessageChunk, - FunctionMessage, - FunctionMessageChunk, - HumanMessage, - HumanMessageChunk, - SystemMessage, - SystemMessageChunk, -) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, 
ChatResult - -from langchain_community.llms.gigachat import _BaseGigaChat - -if TYPE_CHECKING: - import gigachat.models as gm - -logger = logging.getLogger(__name__) - - -def _convert_dict_to_message(message: gm.Messages) -> BaseMessage: - from gigachat.models import FunctionCall, MessagesRole - - additional_kwargs: Dict = {} - if function_call := message.function_call: - if isinstance(function_call, FunctionCall): - additional_kwargs["function_call"] = dict(function_call) - elif isinstance(function_call, dict): - additional_kwargs["function_call"] = function_call - - if message.role == MessagesRole.SYSTEM: - return SystemMessage(content=message.content) - elif message.role == MessagesRole.USER: - return HumanMessage(content=message.content) - elif message.role == MessagesRole.ASSISTANT: - return AIMessage(content=message.content, additional_kwargs=additional_kwargs) - else: - raise TypeError(f"Got unknown role {message.role} {message}") - - -def _convert_message_to_dict(message: gm.BaseMessage) -> gm.Messages: - from gigachat.models import Messages, MessagesRole - - if isinstance(message, SystemMessage): - return Messages(role=MessagesRole.SYSTEM, content=message.content) - elif isinstance(message, HumanMessage): - return Messages(role=MessagesRole.USER, content=message.content) - elif isinstance(message, AIMessage): - return Messages( - role=MessagesRole.ASSISTANT, - content=message.content, - function_call=message.additional_kwargs.get("function_call", None), - ) - elif isinstance(message, ChatMessage): - return Messages(role=MessagesRole(message.role), content=message.content) - elif isinstance(message, FunctionMessage): - return Messages(role=MessagesRole.FUNCTION, content=message.content) - else: - raise TypeError(f"Got unknown type {message}") - - -def _convert_delta_to_message_chunk( - _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] -) -> BaseMessageChunk: - role = _dict.get("role") - content = _dict.get("content") or "" - additional_kwargs: Dict = {} - if _dict.get("function_call"): - function_call = dict(_dict["function_call"]) - if "name" in function_call and function_call["name"] is None: - function_call["name"] = "" - additional_kwargs["function_call"] = function_call - - if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) - elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"]) - elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] - else: - return default_class(content=content) # type: ignore[call-arg] - - -@deprecated( - since="0.3.5", - removal="1.0", - alternative_import="langchain_gigachat.GigaChat", -) -class GigaChat(_BaseGigaChat, BaseChatModel): - """`GigaChat` large language models API. - - To use, you should pass login and password to access GigaChat API or use token. - - Example: - .. code-block:: python - - from langchain_community.chat_models import GigaChat - giga = GigaChat(credentials=..., scope=..., verify_ssl_certs=...) 
- """ - - def _build_payload(self, messages: List[BaseMessage], **kwargs: Any) -> gm.Chat: - from gigachat.models import Chat - - payload = Chat( - messages=[_convert_message_to_dict(m) for m in messages], - ) - - payload.functions = kwargs.get("functions", None) - payload.model = self.model - - if self.profanity_check is not None: - payload.profanity_check = self.profanity_check - if self.temperature is not None: - payload.temperature = self.temperature - if self.top_p is not None: - payload.top_p = self.top_p - if self.max_tokens is not None: - payload.max_tokens = self.max_tokens - if self.repetition_penalty is not None: - payload.repetition_penalty = self.repetition_penalty - if self.update_interval is not None: - payload.update_interval = self.update_interval - - if self.verbose: - logger.warning("Giga request: %s", payload.dict()) - - return payload - - def _create_chat_result(self, response: Any) -> ChatResult: - generations = [] - for res in response.choices: - message = _convert_dict_to_message(res.message) - finish_reason = res.finish_reason - gen = ChatGeneration( - message=message, - generation_info={"finish_reason": finish_reason}, - ) - generations.append(gen) - if finish_reason != "stop": - logger.warning( - "Giga generation stopped with reason: %s", - finish_reason, - ) - if self.verbose: - logger.warning("Giga response: %s", message.content) - llm_output = {"token_usage": response.usage, "model_name": response.model} - return ChatResult(generations=generations, llm_output=llm_output) - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - - payload = self._build_payload(messages, **kwargs) - response = self._client.chat(payload) - - return self._create_chat_result(response) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) - - payload = self._build_payload(messages, **kwargs) - response = await self._client.achat(payload) - - return self._create_chat_result(response) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - payload = self._build_payload(messages, **kwargs) - - for chunk in self._client.stream(payload): - if not isinstance(chunk, dict): - chunk = chunk.dict() - if len(chunk["choices"]) == 0: - continue - - choice = chunk["choices"][0] - content = choice.get("delta", {}).get("content", {}) - chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk) - - finish_reason = choice.get("finish_reason") - - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None - ) - - if run_manager: - run_manager.on_llm_new_token(content) - - yield ChatGenerationChunk(message=chunk, 
generation_info=generation_info) - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - payload = self._build_payload(messages, **kwargs) - - async for chunk in self._client.astream(payload): - if not isinstance(chunk, dict): - chunk = chunk.dict() - if len(chunk["choices"]) == 0: - continue - - choice = chunk["choices"][0] - content = choice.get("delta", {}).get("content", {}) - chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk) - - finish_reason = choice.get("finish_reason") - - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None - ) - - if run_manager: - await run_manager.on_llm_new_token(content) - - yield ChatGenerationChunk(message=chunk, generation_info=generation_info) diff --git a/libs/community/langchain_community/chat_models/gpt_router.py b/libs/community/langchain_community/chat_models/gpt_router.py index 91f314a8a..82ea08feb 100644 --- a/libs/community/langchain_community/chat_models/gpt_router.py +++ b/libs/community/langchain_community/chat_models/gpt_router.py @@ -28,7 +28,16 @@ generate_from_stream, ) from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk +from langchain_core.messages import ( + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessageChunk, + FunctionMessageChunk, + HumanMessageChunk, + SystemMessageChunk, + ToolMessageChunk, +) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from pydantic import BaseModel, Field, SecretStr, model_validator @@ -38,7 +47,6 @@ convert_dict_to_message, convert_message_to_dict, ) -from langchain_community.chat_models.openai import _convert_delta_to_message_chunk if TYPE_CHECKING: from gpt_router.models import ChunkedGenerationResponse, GenerationResponse @@ -49,6 +57,36 @@ DEFAULT_API_BASE_URL = "https://gpt-router-preview.writesonic.com" +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + additional_kwargs: Dict = {} + if _dict.get("function_call"): + function_call = dict(_dict["function_call"]) + if "name" in function_call and function_call["name"] is None: + function_call["name"] = "" + additional_kwargs["function_call"] = function_call + if _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = _dict["tool_calls"] + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"]) + elif role == "tool" or default_class == ToolMessageChunk: + return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] + else: + return default_class(content=content) # type: 
ignore[call-arg] + + class GPTRouterException(Exception): """Error with the `GPTRouter APIs`""" diff --git a/libs/community/langchain_community/chat_models/huggingface.py b/libs/community/langchain_community/chat_models/huggingface.py deleted file mode 100644 index fe029d6d6..000000000 --- a/libs/community/langchain_community/chat_models/huggingface.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Hugging Face Chat Wrapper.""" - -from typing import Any, AsyncIterator, Iterator, List, Optional - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import ( - BaseChatModel, - agenerate_from_stream, - generate_from_stream, -) -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - HumanMessage, - SystemMessage, -) -from langchain_core.outputs import ( - ChatGeneration, - ChatGenerationChunk, - ChatResult, - LLMResult, -) -from pydantic import model_validator -from typing_extensions import Self - -from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint -from langchain_community.llms.huggingface_hub import HuggingFaceHub -from langchain_community.llms.huggingface_text_gen_inference import ( - HuggingFaceTextGenInference, -) - -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.""" - - -@deprecated( - since="0.0.37", - removal="1.0", - alternative_import="langchain_huggingface.ChatHuggingFace", -) -class ChatHuggingFace(BaseChatModel): - """ - Wrapper for using Hugging Face LLM's as ChatModels. - - Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`, - and `HuggingFaceHub` LLMs. - - Upon instantiating this class, the model_id is resolved from the url - provided to the LLM, and the appropriate tokenizer is loaded from - the HuggingFace Hub. 
- - Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat - """ - - llm: Any - """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or - HuggingFaceHub.""" - system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT) - tokenizer: Any = None - model_id: Optional[str] = None - streaming: bool = False - - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - - from transformers import AutoTokenizer - - self._resolve_model_id() - - self.tokenizer = ( - AutoTokenizer.from_pretrained(self.model_id) - if self.tokenizer is None - else self.tokenizer - ) - - @model_validator(mode="after") - def validate_llm(self) -> Self: - if not isinstance( - self.llm, - (HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub), - ): - raise TypeError( - "Expected llm to be one of HuggingFaceTextGenInference, " - f"HuggingFaceEndpoint, HuggingFaceHub, received {type(self.llm)}" - ) - return self - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - request = self._to_chat_prompt(messages) - - for data in self.llm.stream(request, **kwargs): - delta = data - chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - run_manager.on_llm_new_token(delta, chunk=chunk) - yield chunk - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - request = self._to_chat_prompt(messages) - async for data in self.llm.astream(request, **kwargs): - delta = data - chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - await run_manager.on_llm_new_token(delta, chunk=chunk) - yield chunk - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - - llm_input = self._to_chat_prompt(messages) - llm_result = self.llm._generate( - prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs - ) - return self._to_chat_result(llm_result) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if self.streaming: - stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) - - llm_input = self._to_chat_prompt(messages) - llm_result = await self.llm._agenerate( - prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs - ) - return self._to_chat_result(llm_result) - - def _to_chat_prompt( - self, - messages: List[BaseMessage], - ) -> str: - """Convert a list of messages into a prompt format expected by wrapped LLM.""" - if not messages: - raise ValueError("At least one HumanMessage must be provided!") - - if not isinstance(messages[-1], HumanMessage): - raise ValueError("Last message must be a HumanMessage!") - - messages_dicts = [self._to_chatml_format(m) for m in messages] - - return self.tokenizer.apply_chat_template( - messages_dicts, tokenize=False, 
add_generation_prompt=True - ) - - def _to_chatml_format(self, message: BaseMessage) -> dict: - """Convert LangChain message to ChatML format.""" - - if isinstance(message, SystemMessage): - role = "system" - elif isinstance(message, AIMessage): - role = "assistant" - elif isinstance(message, HumanMessage): - role = "user" - else: - raise ValueError(f"Unknown message type: {type(message)}") - - return {"role": role, "content": message.content} - - @staticmethod - def _to_chat_result(llm_result: LLMResult) -> ChatResult: - chat_generations = [] - - for g in llm_result.generations[0]: - chat_generation = ChatGeneration( - message=AIMessage(content=g.text), generation_info=g.generation_info - ) - chat_generations.append(chat_generation) - - return ChatResult( - generations=chat_generations, llm_output=llm_result.llm_output - ) - - def _resolve_model_id(self) -> None: - """Resolve the model_id from the LLM's inference_server_url""" - - from huggingface_hub import list_inference_endpoints - - available_endpoints = list_inference_endpoints("*") - if isinstance(self.llm, HuggingFaceHub) or ( - hasattr(self.llm, "repo_id") and self.llm.repo_id - ): - self.model_id = self.llm.repo_id - return - elif isinstance(self.llm, HuggingFaceTextGenInference): - endpoint_url: Optional[str] = self.llm.inference_server_url - else: - endpoint_url = self.llm.endpoint_url - - for endpoint in available_endpoints: - if endpoint.url == endpoint_url: - self.model_id = endpoint.repository - - if not self.model_id: - raise ValueError( - "Failed to resolve model_id:" - f"Could not find model id for inference server: {endpoint_url}" - "Make sure that your Hugging Face token has access to the endpoint." - ) - - @property - def _llm_type(self) -> str: - return "huggingface-chat-wrapper" diff --git a/libs/community/langchain_community/chat_models/konko.py b/libs/community/langchain_community/chat_models/konko.py index db4f1eded..eb9aafe8e 100644 --- a/libs/community/langchain_community/chat_models/konko.py +++ b/libs/community/langchain_community/chat_models/konko.py @@ -10,6 +10,7 @@ Dict, Iterator, List, + Mapping, Optional, Set, Tuple, @@ -22,7 +23,19 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) -from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk +from langchain_core.language_models.chat_models import ( + generate_from_stream, +) +from langchain_core.messages import ( + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessageChunk, + FunctionMessageChunk, + HumanMessageChunk, + SystemMessageChunk, + ToolMessageChunk, +) from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from pydantic import Field, SecretStr @@ -32,8 +45,6 @@ ) from langchain_community.chat_models.openai import ( ChatOpenAI, - _convert_delta_to_message_chunk, - generate_from_stream, ) from langchain_community.utils.openai import is_openai_v1 @@ -43,6 +54,36 @@ logger = logging.getLogger(__name__) +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + additional_kwargs: Dict = {} + if _dict.get("function_call"): + function_call = dict(_dict["function_call"]) + if "name" in function_call and function_call["name"] is None: + function_call["name"] = "" + additional_kwargs["function_call"] = function_call + if _dict.get("tool_calls"): + 
additional_kwargs["tool_calls"] = _dict["tool_calls"] + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"]) + elif role == "tool" or default_class == ToolMessageChunk: + return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] + else: + return default_class(content=content) # type: ignore[call-arg] + + class ChatKonko(ChatOpenAI): """`ChatKonko` Chat large language models API. diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py deleted file mode 100644 index da107b3c9..000000000 --- a/libs/community/langchain_community/chat_models/litellm.py +++ /dev/null @@ -1,632 +0,0 @@ -""" -Deprecated LiteLLM wrapper. - -⭐ Use `pip install langchain-litellm` and import - `from langchain_litellm import ChatLiteLLM` instead. -""" - -from __future__ import annotations - -import json -import logging -from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Iterator, - List, - Literal, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models import ( - BaseChatModel, - agenerate_from_stream, - generate_from_stream, -) -from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessage, - ChatMessageChunk, - FunctionMessage, - FunctionMessageChunk, - HumanMessage, - HumanMessageChunk, - SystemMessage, - SystemMessageChunk, - ToolCall, - ToolCallChunk, - ToolMessage, -) -from langchain_core.messages.ai import UsageMetadata -from langchain_core.outputs import ( - ChatGeneration, - ChatGenerationChunk, - ChatResult, -) -from langchain_core.runnables import Runnable -from langchain_core.tools import BaseTool -from langchain_core.utils import get_from_dict_or_env, pre_init -from langchain_core.utils.function_calling import convert_to_openai_tool -from pydantic import BaseModel, Field - -logger = logging.getLogger(__name__) - - -class ChatLiteLLMException(Exception): - """Error with the `LiteLLM I/O` library""" - - -def _create_retry_decorator( - llm: ChatLiteLLM, - run_manager: Optional[ - Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] - ] = None, -) -> Callable[[Any], Any]: - """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions""" - import litellm - - errors = [ - litellm.Timeout, - litellm.APIError, - litellm.APIConnectionError, - litellm.RateLimitError, - ] - return create_base_retry_decorator( - error_types=errors, max_retries=llm.max_retries, run_manager=run_manager - ) - - -def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: - role = _dict["role"] - if role == "user": - return 
HumanMessage(content=_dict["content"]) - elif role == "assistant": - # Fix for azure - # Also OpenAI returns None for tool invocations - content = _dict.get("content", "") or "" - - additional_kwargs = {} - if _dict.get("function_call"): - additional_kwargs["function_call"] = dict(_dict["function_call"]) - - if _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = _dict["tool_calls"] - - return AIMessage(content=content, additional_kwargs=additional_kwargs) - elif role == "system": - return SystemMessage(content=_dict["content"]) - elif role == "function": - return FunctionMessage(content=_dict["content"], name=_dict["name"]) - else: - return ChatMessage(content=_dict["content"], role=role) - - -async def acompletion_with_retry( - llm: ChatLiteLLM, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, -) -> Any: - """Use tenacity to retry the async completion call.""" - retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - - @retry_decorator - async def _completion_with_retry(**kwargs: Any) -> Any: - # Use OpenAI's async api https://github.com/openai/openai-python#async-api - return await llm.client.acreate(**kwargs) - - return await _completion_with_retry(**kwargs) - - -def _convert_delta_to_message_chunk( - _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] -) -> BaseMessageChunk: - role = _dict.get("role") - content = _dict.get("content") or "" - if _dict.get("function_call"): - additional_kwargs = {"function_call": dict(_dict["function_call"])} - elif _dict.get("reasoning_content"): - additional_kwargs = {"reasoning_content": _dict["reasoning_content"]} - else: - additional_kwargs = {} - - tool_call_chunks = [] - if raw_tool_calls := _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = raw_tool_calls - try: - tool_call_chunks = [ - ToolCallChunk( - name=rtc["function"].get("name"), - args=rtc["function"].get("arguments"), - id=rtc.get("id"), - index=rtc["index"], - ) - for rtc in raw_tool_calls - ] - except KeyError: - pass - - if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk( - content=content, - additional_kwargs=additional_kwargs, - tool_call_chunks=tool_call_chunks, - ) - elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"]) - elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] - else: - return default_class(content=content) # type: ignore[call-arg] - - -def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: - return { - "type": "function", - "id": tool_call["id"], - "function": { - "name": tool_call["name"], - "arguments": json.dumps(tool_call["args"]), - }, - } - - -def _convert_message_to_dict(message: BaseMessage) -> dict: - message_dict: Dict[str, Any] = {"content": message.content} - if isinstance(message, ChatMessage): - message_dict["role"] = message.role - elif isinstance(message, HumanMessage): - message_dict["role"] = "user" - elif isinstance(message, AIMessage): - message_dict["role"] = "assistant" - if "function_call" in message.additional_kwargs: - message_dict["function_call"] = message.additional_kwargs["function_call"] - if message.tool_calls: - message_dict["tool_calls"] = [ - 
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls - ] - elif "tool_calls" in message.additional_kwargs: - message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] - elif isinstance(message, SystemMessage): - message_dict["role"] = "system" - elif isinstance(message, FunctionMessage): - message_dict["role"] = "function" - message_dict["name"] = message.name - elif isinstance(message, ToolMessage): - message_dict["role"] = "tool" - message_dict["tool_call_id"] = message.tool_call_id - else: - raise ValueError(f"Got unknown type {message}") - if "name" in message.additional_kwargs: - message_dict["name"] = message.additional_kwargs["name"] - return message_dict - - -_OPENAI_MODELS = [ - "o1-mini", - "o1-preview", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4o", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "gpt-4-turbo", - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo", - "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k-0613", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0314", - "gpt-4-32k-0613", -] - - -@deprecated( - since="0.3.24", - removal="1.0", - alternative_import="langchain_litellm.ChatLiteLLM", -) -class ChatLiteLLM(BaseChatModel): - """DEPRECATED – use `langchain_litellm.ChatLiteLLM` instead.""" - - client: Any = None #: :meta private: - model: str = "gpt-3.5-turbo" - model_name: Optional[str] = None - """Model name to use.""" - openai_api_key: Optional[str] = None - azure_api_key: Optional[str] = None - anthropic_api_key: Optional[str] = None - replicate_api_key: Optional[str] = None - cohere_api_key: Optional[str] = None - openrouter_api_key: Optional[str] = None - api_key: Optional[str] = None - streaming: bool = False - api_base: Optional[str] = None - organization: Optional[str] = None - custom_llm_provider: Optional[str] = None - request_timeout: Optional[Union[float, Tuple[float, float]]] = None - temperature: Optional[float] = None - """Run inference with this temperature. Must be in the closed - interval [0.0, 1.0].""" - model_kwargs: Dict[str, Any] = Field(default_factory=dict) - """Holds any model parameters valid for API call not explicitly specified.""" - top_p: Optional[float] = None - """Decode using nucleus sampling: consider the smallest set of tokens whose - probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" - top_k: Optional[int] = None - """Decode using top-k sampling: consider the set of top_k most probable tokens. - Must be positive.""" - n: Optional[int] = None - """Number of chat completions to generate for each prompt. 
Note that the API may - not return the full n completions if duplicates are generated.""" - max_tokens: Optional[int] = None - - max_retries: int = 1 - - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling OpenAI API.""" - set_model_value = self.model - if self.model_name is not None: - set_model_value = self.model_name - return { - "model": set_model_value, - "force_timeout": self.request_timeout, - "max_tokens": self.max_tokens, - "stream": self.streaming, - "n": self.n, - "temperature": self.temperature, - "custom_llm_provider": self.custom_llm_provider, - **self.model_kwargs, - } - - @property - def _client_params(self) -> Dict[str, Any]: - """Get the parameters used for the openai client.""" - set_model_value = self.model - if self.model_name is not None: - set_model_value = self.model_name - self.client.api_base = self.api_base - self.client.api_key = self.api_key - for named_api_key in [ - "openai_api_key", - "azure_api_key", - "anthropic_api_key", - "replicate_api_key", - "cohere_api_key", - "openrouter_api_key", - ]: - if api_key_value := getattr(self, named_api_key): - setattr( - self.client, - named_api_key.replace("_api_key", "_key"), - api_key_value, - ) - self.client.organization = self.organization - creds: Dict[str, Any] = { - "model": set_model_value, - "force_timeout": self.request_timeout, - "api_base": self.api_base, - } - return {**self._default_params, **creds} - - def completion_with_retry( - self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any - ) -> Any: - """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(self, run_manager=run_manager) - - @retry_decorator - def _completion_with_retry(**kwargs: Any) -> Any: - return self.client.completion(**kwargs) - - return _completion_with_retry(**kwargs) - - @pre_init - def validate_environment(cls, values: Dict) -> Dict: - """Validate api key, python package exists, temperature, top_p, and top_k.""" - try: - import litellm - except ImportError: - raise ChatLiteLLMException( - "Could not import litellm python package. 
" - "Please install it with `pip install litellm`" - ) - - values["openai_api_key"] = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY", default="" - ) - values["azure_api_key"] = get_from_dict_or_env( - values, "azure_api_key", "AZURE_API_KEY", default="" - ) - values["anthropic_api_key"] = get_from_dict_or_env( - values, "anthropic_api_key", "ANTHROPIC_API_KEY", default="" - ) - values["replicate_api_key"] = get_from_dict_or_env( - values, "replicate_api_key", "REPLICATE_API_KEY", default="" - ) - values["openrouter_api_key"] = get_from_dict_or_env( - values, "openrouter_api_key", "OPENROUTER_API_KEY", default="" - ) - values["cohere_api_key"] = get_from_dict_or_env( - values, "cohere_api_key", "COHERE_API_KEY", default="" - ) - values["huggingface_api_key"] = get_from_dict_or_env( - values, "huggingface_api_key", "HUGGINGFACE_API_KEY", default="" - ) - values["together_ai_api_key"] = get_from_dict_or_env( - values, "together_ai_api_key", "TOGETHERAI_API_KEY", default="" - ) - values["client"] = litellm - - if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: - raise ValueError("temperature must be in the range [0.0, 1.0]") - - if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: - raise ValueError("top_p must be in the range [0.0, 1.0]") - - if values["top_k"] is not None and values["top_k"] <= 0: - raise ValueError("top_k must be positive") - - return values - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - response = self.completion_with_retry( - messages=message_dicts, run_manager=run_manager, **params - ) - return self._create_chat_result(response) - - def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: - generations = [] - token_usage = response.get("usage", {}) - for res in response["choices"]: - message = _convert_dict_to_message(res["message"]) - if isinstance(message, AIMessage): - message.response_metadata = { - "model_name": self.model_name or self.model - } - message.usage_metadata = _create_usage_metadata(token_usage) - gen = ChatGeneration( - message=message, - generation_info=dict(finish_reason=res.get("finish_reason")), - ) - generations.append(gen) - set_model_value = self.model - if self.model_name is not None: - set_model_value = self.model_name - llm_output = {"token_usage": token_usage, "model": set_model_value} - return ChatResult(generations=generations, llm_output=llm_output) - - def _create_message_dicts( - self, messages: List[BaseMessage], stop: Optional[List[str]] - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - params = self._client_params - if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") - params["stop"] = stop - message_dicts = [_convert_message_to_dict(m) for m in messages] - return message_dicts, params - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - 
message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs, "stream": True} - - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - added_model_name = False - for chunk in self.completion_with_retry( - messages=message_dicts, run_manager=run_manager, **params - ): - if not isinstance(chunk, dict): - chunk = chunk.model_dump() - if len(chunk["choices"]) == 0: - continue - delta = chunk["choices"][0]["delta"] - usage = chunk.get("usage", {}) - chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) - if isinstance(chunk, AIMessageChunk): - if not added_model_name: - chunk.response_metadata = { - "model_name": self.model_name or self.model - } - added_model_name = True - chunk.usage_metadata = _create_usage_metadata(usage) - default_chunk_class = chunk.__class__ - cg_chunk = ChatGenerationChunk(message=chunk) - if run_manager: - run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) - yield cg_chunk - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs, "stream": True} - - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - added_model_name = False - async for chunk in await acompletion_with_retry( - self, messages=message_dicts, run_manager=run_manager, **params - ): - if not isinstance(chunk, dict): - chunk = chunk.model_dump() - if len(chunk["choices"]) == 0: - continue - delta = chunk["choices"][0]["delta"] - usage = chunk.get("usage", {}) - chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) - if isinstance(chunk, AIMessageChunk): - if not added_model_name: - chunk.response_metadata = { - "model_name": self.model_name or self.model - } - added_model_name = True - chunk.usage_metadata = _create_usage_metadata(usage) - default_chunk_class = chunk.__class__ - cg_chunk = ChatGenerationChunk(message=chunk) - if run_manager: - await run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) - yield cg_chunk - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._astream( - messages=messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) - - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - response = await acompletion_with_retry( - self, messages=message_dicts, run_manager=run_manager, **params - ) - return self._create_chat_result(response) - - def bind_tools( - self, - tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ] = None, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, AIMessage]: - """Bind tool-like objects to this chat model. - - LiteLLM expects tools argument in OpenAI format. - - Args: - tools: A list of tool definitions to bind to this chat model. - Can be a dictionary, pydantic model, callable, or BaseTool. 
Pydantic - models, callables, and BaseTools will be automatically converted to - their schema dictionary representation. - tool_choice: Which tool to require the model to call. Options are: - - str of the form ``"<>"``: calls <> tool. - - ``"auto"``: - automatically selects a tool (including no tool). - - ``"none"``: - does not call a tool. - - ``"any"`` or ``"required"`` or ``True``: - forces least one tool to be called. - - dict of the form: - ``{"type": "function", "function": {"name": <>}}`` - - ``False`` or ``None``: no effect - **kwargs: Any additional parameters to pass to the - :class:`~langchain.runnable.Runnable` constructor. - """ - - formatted_tools = [convert_to_openai_tool(tool) for tool in tools] - - # In case of openai if tool_choice is `any` or if bool has been provided we - # change it to `required` as that is supported by openai. - if ( - (self.model is not None and "azure" in self.model) - or (self.model_name is not None and "azure" in self.model_name) - or (self.model is not None and self.model in _OPENAI_MODELS) - or (self.model_name is not None and self.model_name in _OPENAI_MODELS) - ) and (tool_choice == "any" or isinstance(tool_choice, bool)): - tool_choice = "required" - # If tool_choice is bool apart from openai we make it `any` - elif isinstance(tool_choice, bool): - tool_choice = "any" - elif isinstance(tool_choice, dict): - tool_names = [ - formatted_tool["function"]["name"] for formatted_tool in formatted_tools - ] - if not any( - tool_name == tool_choice["function"]["name"] for tool_name in tool_names - ): - raise ValueError( - f"Tool choice {tool_choice} was specified, but the only " - f"provided tools were {tool_names}." - ) - return super().bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs) - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" - set_model_value = self.model - if self.model_name is not None: - set_model_value = self.model_name - return { - "model": set_model_value, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "n": self.n, - } - - @property - def _llm_type(self) -> str: - return "litellm-chat" - - -def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata: - input_tokens = token_usage.get("prompt_tokens", 0) - output_tokens = token_usage.get("completion_tokens", 0) - return UsageMetadata( - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - ) diff --git a/libs/community/langchain_community/chat_models/litellm_router.py b/libs/community/langchain_community/chat_models/litellm_router.py deleted file mode 100644 index 4fce0d59c..000000000 --- a/libs/community/langchain_community/chat_models/litellm_router.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -Deprecated LiteLLM wrapper. - -⭐ Use `pip install langchain-litellm` and import - `from langchain_litellm import ChatLiteLLMRouter` instead. 
-""" - -from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional, Type - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import ( - agenerate_from_stream, - generate_from_stream, -) -from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult - -from langchain_community.chat_models.litellm import ( - ChatLiteLLM, - _convert_delta_to_message_chunk, - _convert_dict_to_message, -) - -token_usage_key_name = "token_usage" # nosec # incorrectly flagged as password -model_extra_key_name = "model_extra" # nosec # incorrectly flagged as password - - -def get_llm_output(usage: Any, **params: Any) -> dict: - """Get llm output from usage and params.""" - llm_output = {token_usage_key_name: usage} - # copy over metadata (metadata came from router completion call) - metadata = params["metadata"] - for key in metadata: - if key not in llm_output: - # if token usage in metadata, prefer metadata's copy of it - llm_output[key] = metadata[key] - return llm_output - - -@deprecated( - since="0.3.24", - removal="1.0", - alternative_import="langchain_litellm.ChatLiteLLMRouter", -) -class ChatLiteLLMRouter(ChatLiteLLM): - """DEPRECATED – use `langchain_litellm.ChatLiteLLMRouter` instead.""" - - router: Any - - def __init__(self, *, router: Any, **kwargs: Any) -> None: - """Construct Chat LiteLLM Router.""" - super().__init__(router=router, **kwargs) # type: ignore[call-arg] - self.router = router - - @property - def _llm_type(self) -> str: - return "LiteLLMRouter" - - def _prepare_params_for_router(self, params: Any) -> None: - # allow the router to set api_base based on its model choice - api_base_key_name = "api_base" - if api_base_key_name in params and params[api_base_key_name] is None: - del params[api_base_key_name] - - # add metadata so router can fill it below - params.setdefault("metadata", {}) - - def set_default_model(self, model_name: str) -> None: - """Set the default model to use for completion calls. - - Sets `self.model` to `model_name` if it is in the litellm router's - (`self.router`) model list. This provides the default model to use - for completion calls if no `model` kwarg is provided. 
- """ - model_list = self.router.model_list - if not model_list: - raise ValueError("model_list is None or empty.") - for entry in model_list: - if entry["model_name"] == model_name: - self.model = model_name - return - raise ValueError(f"Model {model_name} not found in model_list.") - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - self._prepare_params_for_router(params) - - response = self.router.completion( - messages=message_dicts, - **params, - ) - return self._create_chat_result(response, **params) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs, "stream": True} - self._prepare_params_for_router(params) - - for chunk in self.router.completion(messages=message_dicts, **params): - if len(chunk["choices"]) == 0: - continue - delta = chunk["choices"][0]["delta"] - chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) - default_chunk_class = chunk.__class__ - cg_chunk = ChatGenerationChunk(message=chunk) - if run_manager: - run_manager.on_llm_new_token( - str(chunk.content), chunk=cg_chunk, **params - ) - yield cg_chunk - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs, "stream": True} - self._prepare_params_for_router(params) - - async for chunk in await self.router.acompletion( - messages=message_dicts, **params - ): - if len(chunk["choices"]) == 0: - continue - delta = chunk["choices"][0]["delta"] - chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) - default_chunk_class = chunk.__class__ - cg_chunk = ChatGenerationChunk(message=chunk) - if run_manager: - await run_manager.on_llm_new_token( - str(chunk.content), chunk=cg_chunk, **params - ) - yield cg_chunk - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._astream( - messages=messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) - - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - self._prepare_params_for_router(params) - - response = await self.router.acompletion( - messages=message_dicts, - **params, - ) - return self._create_chat_result(response, **params) - - # from - # 
https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py - # but modified to handle LiteLLM Usage class - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: - overall_token_usage: dict = {} - system_fingerprint = None - for output in llm_outputs: - if output is None: - # Happens in streaming - continue - token_usage = output["token_usage"] - if token_usage is not None: - # get dict from LiteLLM Usage class - for k, v in token_usage.model_dump().items(): - if k in overall_token_usage and overall_token_usage[k] is not None: - overall_token_usage[k] += v - else: - overall_token_usage[k] = v - if system_fingerprint is None: - system_fingerprint = output.get("system_fingerprint") - combined = {"token_usage": overall_token_usage, "model_name": self.model} - if system_fingerprint: - combined["system_fingerprint"] = system_fingerprint - return combined - - def _create_chat_result( - self, response: Mapping[str, Any], **params: Any - ) -> ChatResult: - from litellm.utils import Usage - - generations = [] - for res in response["choices"]: - message = _convert_dict_to_message(res["message"]) - gen = ChatGeneration( - message=message, - generation_info=dict(finish_reason=res.get("finish_reason")), - ) - generations.append(gen) - token_usage = response.get("usage", Usage(prompt_tokens=0, total_tokens=0)) - llm_output = get_llm_output(token_usage, **params) - return ChatResult(generations=generations, llm_output=llm_output) diff --git a/libs/community/langchain_community/chat_models/moonshot.py b/libs/community/langchain_community/chat_models/moonshot.py index 6d31426fd..7f2eff555 100644 --- a/libs/community/langchain_community/chat_models/moonshot.py +++ b/libs/community/langchain_community/chat_models/moonshot.py @@ -8,7 +8,7 @@ pre_init, ) -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.llms.moonshot import MOONSHOT_SERVICE_URL_BASE, MoonshotCommon diff --git a/libs/community/langchain_community/chat_models/octoai.py b/libs/community/langchain_community/chat_models/octoai.py index e2d6b9275..5cdb066ae 100644 --- a/libs/community/langchain_community/chat_models/octoai.py +++ b/libs/community/langchain_community/chat_models/octoai.py @@ -46,7 +46,7 @@ class ChatOctoAI(ChatOpenAI): """ octoai_api_base: str = Field(default=DEFAULT_API_BASE) - octoai_api_token: SecretStr = Field(default=SecretStr(""), alias="api_key") + octoai_api_token: SecretStr = Field(default=SecretStr("")) model_name: str = Field(default=DEFAULT_MODEL, alias="model") @property diff --git a/libs/community/langchain_community/chat_models/ollama.py b/libs/community/langchain_community/chat_models/ollama.py deleted file mode 100644 index 0a3788b48..000000000 --- a/libs/community/langchain_community/chat_models/ollama.py +++ /dev/null @@ -1,398 +0,0 @@ -import json -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast - -from langchain_core._api import deprecated -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult - -from langchain_community.llms.ollama import 
OllamaEndpointNotFoundError, _OllamaCommon - - -@deprecated("0.0.3", alternative="_chat_stream_response_to_chat_generation_chunk") -def _stream_response_to_chat_generation_chunk( - stream_response: str, -) -> ChatGenerationChunk: - """Convert a stream response to a generation chunk.""" - parsed_response = json.loads(stream_response) - generation_info = parsed_response if parsed_response.get("done") is True else None - return ChatGenerationChunk( - message=AIMessageChunk(content=parsed_response.get("response", "")), - generation_info=generation_info, - ) - - -def _chat_stream_response_to_chat_generation_chunk( - stream_response: str, -) -> ChatGenerationChunk: - """Convert a stream response to a generation chunk.""" - parsed_response = json.loads(stream_response) - generation_info = parsed_response if parsed_response.get("done") is True else None - return ChatGenerationChunk( - message=AIMessageChunk( - content=parsed_response.get("message", {}).get("content", "") - ), - generation_info=generation_info, - ) - - -@deprecated( - since="0.3.1", - removal="1.0.0", - alternative_import="langchain_ollama.ChatOllama", -) -class ChatOllama(BaseChatModel, _OllamaCommon): - """Ollama locally runs large language models. - - To use, follow the instructions at https://ollama.ai/. - - Example: - .. code-block:: python - - from langchain_community.chat_models import ChatOllama - ollama = ChatOllama(model="llama2") - """ - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "ollama-chat" - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this model can be serialized by Langchain.""" - return False - - def _get_ls_params( - self, stop: Optional[List[str]] = None, **kwargs: Any - ) -> LangSmithParams: - """Get standard params for tracing.""" - params = self._get_invocation_params(stop=stop, **kwargs) - ls_params = LangSmithParams( - ls_provider="ollama", - ls_model_name=self.model, - ls_model_type="chat", - ls_temperature=params.get("temperature", self.temperature), - ) - if ls_max_tokens := params.get("num_predict", self.num_predict): - ls_params["ls_max_tokens"] = ls_max_tokens - if ls_stop := stop or params.get("stop", None) or self.stop: - ls_params["ls_stop"] = ls_stop - return ls_params - - @deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages") - def _format_message_as_text(self, message: BaseMessage) -> str: - if isinstance(message, ChatMessage): - message_text = f"\n\n{message.role.capitalize()}: {message.content}" - elif isinstance(message, HumanMessage): - if isinstance(message.content, List): - first_content = cast(List[Dict], message.content)[0] - content_type = first_content.get("type") - if content_type == "text": - message_text = f"[INST] {first_content['text']} [/INST]" - elif content_type == "image_url": - message_text = first_content["image_url"]["url"] - else: - message_text = f"[INST] {message.content} [/INST]" - elif isinstance(message, AIMessage): - message_text = f"{message.content}" - elif isinstance(message, SystemMessage): - message_text = f"<> {message.content} <>" - else: - raise ValueError(f"Got unknown type {message}") - return message_text - - def _format_messages_as_text(self, messages: List[BaseMessage]) -> str: - return "\n".join( - [self._format_message_as_text(message) for message in messages] - ) - - def _convert_messages_to_ollama_messages( - self, messages: List[BaseMessage] - ) -> List[Dict[str, Union[str, List[str]]]]: - ollama_messages: List = [] - for message in messages: - role = "" - if 
isinstance(message, HumanMessage): - role = "user" - elif isinstance(message, AIMessage): - role = "assistant" - elif isinstance(message, SystemMessage): - role = "system" - else: - raise ValueError("Received unsupported message type for Ollama.") - - content = "" - images = [] - if isinstance(message.content, str): - content = message.content - else: - for content_part in cast(List[Dict], message.content): - if content_part.get("type") == "text": - content += f"\n{content_part['text']}" - elif content_part.get("type") == "image_url": - image_url = None - temp_image_url = content_part.get("image_url") - if isinstance(temp_image_url, str): - image_url = content_part["image_url"] - elif ( - isinstance(temp_image_url, dict) and "url" in temp_image_url - ): - image_url = temp_image_url["url"] - else: - raise ValueError( - "Only string image_url or dict with string 'url' " - "inside content parts are supported." - ) - - image_url_components = image_url.split(",") - # Support data:image/jpeg;base64, format - # and base64 strings - if len(image_url_components) > 1: - images.append(image_url_components[1]) - else: - images.append(image_url_components[0]) - - else: - raise ValueError( - "Unsupported message content type. " - "Must either have type 'text' or type 'image_url' " - "with a string 'image_url' field." - ) - - ollama_messages.append( - { - "role": role, - "content": content, - "images": images, - } - ) - - return ollama_messages - - def _create_chat_stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> Iterator[str]: - payload = { - "model": self.model, - "messages": self._convert_messages_to_ollama_messages(messages), - } - yield from self._create_stream( - payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat", **kwargs - ) - - async def _acreate_chat_stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> AsyncIterator[str]: - payload = { - "model": self.model, - "messages": self._convert_messages_to_ollama_messages(messages), - } - async for stream_resp in self._acreate_stream( - payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat", **kwargs - ): - yield stream_resp - - def _chat_stream_with_aggregation( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - verbose: bool = False, - **kwargs: Any, - ) -> ChatGenerationChunk: - final_chunk: Optional[ChatGenerationChunk] = None - for stream_resp in self._create_chat_stream(messages, stop, **kwargs): - if stream_resp: - chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp) - if final_chunk is None: - final_chunk = chunk - else: - final_chunk += chunk - if run_manager: - run_manager.on_llm_new_token( - chunk.text, - chunk=chunk, - verbose=verbose, - ) - if final_chunk is None: - raise ValueError("No data received from Ollama stream.") - - return final_chunk - - async def _achat_stream_with_aggregation( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - verbose: bool = False, - **kwargs: Any, - ) -> ChatGenerationChunk: - final_chunk: Optional[ChatGenerationChunk] = None - async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs): - if stream_resp: - chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp) - if final_chunk is None: - final_chunk = chunk - else: - final_chunk += chunk - if run_manager: - await 
run_manager.on_llm_new_token( - chunk.text, - chunk=chunk, - verbose=verbose, - ) - if final_chunk is None: - raise ValueError("No data received from Ollama stream.") - - return final_chunk - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Call out to Ollama's generate endpoint. - - Args: - messages: The list of base messages to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - Chat generations from the model - - Example: - .. code-block:: python - - response = ollama([ - HumanMessage(content="Tell me about the history of AI") - ]) - """ - - final_chunk = self._chat_stream_with_aggregation( - messages, - stop=stop, - run_manager=run_manager, - verbose=self.verbose, - **kwargs, - ) - chat_generation = ChatGeneration( - message=AIMessage(content=final_chunk.text), - generation_info=final_chunk.generation_info, - ) - return ChatResult(generations=[chat_generation]) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Call out to Ollama's generate endpoint. - - Args: - messages: The list of base messages to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - Chat generations from the model - - Example: - .. code-block:: python - - response = ollama([ - HumanMessage(content="Tell me about the history of AI") - ]) - """ - - final_chunk = await self._achat_stream_with_aggregation( - messages, - stop=stop, - run_manager=run_manager, - verbose=self.verbose, - **kwargs, - ) - chat_generation = ChatGeneration( - message=AIMessage(content=final_chunk.text), - generation_info=final_chunk.generation_info, - ) - return ChatResult(generations=[chat_generation]) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - try: - for stream_resp in self._create_chat_stream(messages, stop, **kwargs): - if stream_resp: - chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp) - if run_manager: - run_manager.on_llm_new_token( - chunk.text, - chunk=chunk, - verbose=self.verbose, - ) - yield chunk - except OllamaEndpointNotFoundError: - yield from self._legacy_stream(messages, stop, **kwargs) - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs): - if stream_resp: - chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp) - if run_manager: - await run_manager.on_llm_new_token( - chunk.text, - chunk=chunk, - verbose=self.verbose, - ) - yield chunk - - @deprecated("0.0.3", alternative="_stream") - def _legacy_stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - prompt = self._format_messages_as_text(messages) - for stream_resp in self._create_generate_stream(prompt, stop, **kwargs): - if stream_resp: - chunk = _stream_response_to_chat_generation_chunk(stream_resp) - if run_manager: - 
run_manager.on_llm_new_token( - chunk.text, - chunk=chunk, - verbose=self.verbose, - ) - yield chunk diff --git a/libs/community/langchain_community/chat_models/openai.py b/libs/community/langchain_community/chat_models/openai.py index 7b69019ba..b99726710 100644 --- a/libs/community/langchain_community/chat_models/openai.py +++ b/libs/community/langchain_community/chat_models/openai.py @@ -1,4 +1,7 @@ -"""OpenAI chat wrapper.""" +"""DO NOT USE; KEPT FOR BACKWARDS COMPAT. + +THIS MAY BE DELETED AT ANY POINT. +""" from __future__ import annotations @@ -237,7 +240,7 @@ def is_lc_serializable(cls) -> bool: openai_api_key: Optional[str] = Field(default=None, alias="api_key") """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_api_base: Optional[str] = Field(default=None, alias="base_url") - """Base URL path for API requests, leave blank if not using a proxy or service + """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" openai_organization: Optional[str] = Field(default=None, alias="organization") """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" @@ -246,7 +249,7 @@ def is_lc_serializable(cls) -> bool: request_timeout: Union[float, Tuple[float, float], Any, None] = Field( default=None, alias="timeout" ) - """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or + """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or None.""" max_retries: int = Field(default=2) """Maximum number of retries to make when generating.""" @@ -257,14 +260,14 @@ def is_lc_serializable(cls) -> bool: max_tokens: Optional[int] = None """Maximum number of tokens to generate.""" tiktoken_model_name: Optional[str] = None - """The model name to pass to tiktoken when using this class. - Tiktoken is used to count the number of tokens in documents to constrain - them to be under a certain limit. By default, when set to None, this will - be the same as the embedding model name. However, there are some cases - where you may want to use this Embedding class with a model name not - supported by tiktoken. This can include when using Azure embeddings or - when using one of the many model providers that expose an OpenAI-like - API but with different models. In those cases, in order to avoid erroring + """The model name to pass to tiktoken when using this class. + Tiktoken is used to count the number of tokens in documents to constrain + them to be under a certain limit. By default, when set to None, this will + be the same as the embedding model name. However, there are some cases + where you may want to use this Embedding class with a model name not + supported by tiktoken. This can include when using Azure embeddings or + when using one of the many model providers that expose an OpenAI-like + API but with different models. 
In those cases, in order to avoid erroring when tiktoken is called, you can specify a model name to use here.""" default_headers: Union[Mapping[str, str], None] = None default_query: Union[Mapping[str, object], None] = None diff --git a/libs/community/langchain_community/chat_models/perplexity.py b/libs/community/langchain_community/chat_models/perplexity.py deleted file mode 100644 index e7d57ea04..000000000 --- a/libs/community/langchain_community/chat_models/perplexity.py +++ /dev/null @@ -1,525 +0,0 @@ -"""Wrapper around Perplexity APIs.""" - -from __future__ import annotations - -import logging -from operator import itemgetter -from typing import ( - Any, - Dict, - Iterator, - List, - Literal, - Mapping, - Optional, - Tuple, - Type, - TypeVar, - Union, -) - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models import ( - BaseChatModel, - generate_from_stream, -) -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessage, - ChatMessageChunk, - FunctionMessageChunk, - HumanMessage, - HumanMessageChunk, - SystemMessage, - SystemMessageChunk, - ToolMessageChunk, -) -from langchain_core.messages.ai import UsageMetadata -from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough -from langchain_core.utils import from_env, get_pydantic_field_names -from langchain_core.utils.pydantic import ( - is_basemodel_subclass, -) -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator -from typing_extensions import Self - -_BM = TypeVar("_BM", bound=BaseModel) -_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] -_DictOrPydantic = Union[Dict, _BM] - -logger = logging.getLogger(__name__) - - -def _is_pydantic_class(obj: Any) -> bool: - return isinstance(obj, type) and is_basemodel_subclass(obj) - - -def _create_usage_metadata(token_usage: dict) -> UsageMetadata: - input_tokens = token_usage.get("prompt_tokens", 0) - output_tokens = token_usage.get("completion_tokens", 0) - total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens) - return UsageMetadata( - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=total_tokens, - ) - - -@deprecated( - since="0.3.21", - removal="1.0", - alternative_import="langchain_perplexity.ChatPerplexity", -) -class ChatPerplexity(BaseChatModel): - """`Perplexity AI` Chat models API. - - Setup: - To use, you should have the ``openai`` python package installed, and the - environment variable ``PPLX_API_KEY`` set to your API key. - Any parameters that are valid to be passed to the openai.create call - can be passed in, even if not explicitly saved on this class. - - .. code-block:: bash - - pip install openai - export PPLX_API_KEY=your_api_key - - Key init args - completion params: - model: str - Name of the model to use. e.g. "llama-3.1-sonar-small-128k-online" - temperature: float - Sampling temperature to use. Default is 0.7 - max_tokens: Optional[int] - Maximum number of tokens to generate. - streaming: bool - Whether to stream the results or not. - - Key init args - client params: - pplx_api_key: Optional[str] - API key for PerplexityChat API. Default is None. 
- request_timeout: Optional[Union[float, Tuple[float, float]]] - Timeout for requests to PerplexityChat completion API. Default is None. - max_retries: int - Maximum number of retries to make when generating. - - See full list of supported init args and their descriptions in the params section. - - Instantiate: - .. code-block:: python - - from langchain_community.chat_models import ChatPerplexity - - llm = ChatPerplexity( - model="llama-3.1-sonar-small-128k-online", - temperature=0.7, - ) - - Invoke: - .. code-block:: python - - messages = [ - ("system", "You are a chatbot."), - ("user", "Hello!") - ] - llm.invoke(messages) - - Invoke with structured output: - .. code-block:: python - - from pydantic import BaseModel - - class StructuredOutput(BaseModel): - role: str - content: str - - llm.with_structured_output(StructuredOutput) - llm.invoke(messages) - - Invoke with perplexity-specific params: - .. code-block:: python - - llm.invoke(messages, extra_body={"search_recency_filter": "week"}) - - Stream: - .. code-block:: python - - for chunk in llm.stream(messages): - print(chunk.content) - - Token usage: - .. code-block:: python - - response = llm.invoke(messages) - response.usage_metadata - - Response metadata: - .. code-block:: python - - response = llm.invoke(messages) - response.response_metadata - - """ # noqa: E501 - - client: Any = None #: :meta private: - model: str = "llama-3.1-sonar-small-128k-online" - """Model name.""" - temperature: float = 0.7 - """What sampling temperature to use.""" - model_kwargs: Dict[str, Any] = Field(default_factory=dict) - """Holds any model parameters valid for `create` call not explicitly specified.""" - pplx_api_key: Optional[str] = Field( - default_factory=from_env("PPLX_API_KEY", default=None), alias="api_key" - ) - """Base URL path for API requests, - leave blank if not using a proxy or service emulator.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = Field( - None, alias="timeout" - ) - """Timeout for requests to PerplexityChat completion API. Default is None.""" - max_retries: int = 6 - """Maximum number of retries to make when generating.""" - streaming: bool = False - """Whether to stream the results or not.""" - max_tokens: Optional[int] = None - """Maximum number of tokens to generate.""" - - model_config = ConfigDict( - populate_by_name=True, - ) - - @property - def lc_secrets(self) -> Dict[str, str]: - return {"pplx_api_key": "PPLX_API_KEY"} - - @model_validator(mode="before") - @classmethod - def build_extra(cls, values: Dict[str, Any]) -> Any: - """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = get_pydantic_field_names(cls) - extra = values.get("model_kwargs", {}) - for field_name in list(values): - if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") - if field_name not in all_required_field_names: - logger.warning( - f"""WARNING! {field_name} is not a default parameter. - {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""" - ) - extra[field_name] = values.pop(field_name) - - invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) - if invalid_model_kwargs: - raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - f"Instead they were passed in as part of `model_kwargs` parameter." 
- ) - - values["model_kwargs"] = extra - return values - - @model_validator(mode="after") - def validate_environment(self) -> Self: - """Validate that api key and python package exists in environment.""" - try: - import openai - except ImportError: - raise ImportError( - "Could not import openai python package. " - "Please install it with `pip install openai`." - ) - try: - self.client = openai.OpenAI( - api_key=self.pplx_api_key, base_url="https://api.perplexity.ai" - ) - except AttributeError: - raise ValueError( - "`openai` has no `ChatCompletion` attribute, this is likely " - "due to an old version of the openai package. Try upgrading it " - "with `pip install --upgrade openai`." - ) - return self - - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling PerplexityChat API.""" - return { - "max_tokens": self.max_tokens, - "stream": self.streaming, - "temperature": self.temperature, - **self.model_kwargs, - } - - def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]: - if isinstance(message, ChatMessage): - message_dict = {"role": message.role, "content": message.content} - elif isinstance(message, SystemMessage): - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, HumanMessage): - message_dict = {"role": "user", "content": message.content} - elif isinstance(message, AIMessage): - message_dict = {"role": "assistant", "content": message.content} - else: - raise TypeError(f"Got unknown type {message}") - return message_dict - - def _create_message_dicts( - self, messages: List[BaseMessage], stop: Optional[List[str]] - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - params = dict(self._invocation_params) - if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") - params["stop"] = stop - message_dicts = [self._convert_message_to_dict(m) for m in messages] - return message_dicts, params - - def _convert_delta_to_message_chunk( - self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] - ) -> BaseMessageChunk: - role = _dict.get("role") - content = _dict.get("content") or "" - additional_kwargs: Dict = {} - if _dict.get("function_call"): - function_call = dict(_dict["function_call"]) - if "name" in function_call and function_call["name"] is None: - function_call["name"] = "" - additional_kwargs["function_call"] = function_call - if _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = _dict["tool_calls"] - - if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) - elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"]) - elif role == "tool" or default_class == ToolMessageChunk: - return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) - elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] - else: - return default_class(content=content) # type: ignore[call-arg] - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> 
Iterator[ChatGenerationChunk]: - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - default_chunk_class = AIMessageChunk - params.pop("stream", None) - if stop: - params["stop_sequences"] = stop - stream_resp = self.client.chat.completions.create( - messages=message_dicts, stream=True, **params - ) - first_chunk = True - prev_total_usage: Optional[UsageMetadata] = None - for chunk in stream_resp: - if not isinstance(chunk, dict): - chunk = chunk.dict() - # Collect standard usage metadata (transform from aggregate to delta) - if total_usage := chunk.get("usage"): - lc_total_usage = _create_usage_metadata(total_usage) - if prev_total_usage: - usage_metadata: Optional[UsageMetadata] = { - "input_tokens": lc_total_usage["input_tokens"] - - prev_total_usage["input_tokens"], - "output_tokens": lc_total_usage["output_tokens"] - - prev_total_usage["output_tokens"], - "total_tokens": lc_total_usage["total_tokens"] - - prev_total_usage["total_tokens"], - } - else: - usage_metadata = lc_total_usage - prev_total_usage = lc_total_usage - else: - usage_metadata = None - if len(chunk["choices"]) == 0: - continue - choice = chunk["choices"][0] - - additional_kwargs = {} - if first_chunk: - additional_kwargs["citations"] = chunk.get("citations", []) - for attr in ["images", "related_questions"]: - if attr in chunk: - additional_kwargs[attr] = chunk[attr] - - chunk = self._convert_delta_to_message_chunk( - choice["delta"], default_chunk_class - ) - - if isinstance(chunk, AIMessageChunk) and usage_metadata: - chunk.usage_metadata = usage_metadata - - if first_chunk: - chunk.additional_kwargs |= additional_kwargs - first_chunk = False - - finish_reason = choice.get("finish_reason") - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None - ) - default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) - yield chunk - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - if stream_iter: - return generate_from_stream(stream_iter) - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - response = self.client.chat.completions.create(messages=message_dicts, **params) - if usage := getattr(response, "usage", None): - usage_metadata = _create_usage_metadata(usage.model_dump()) - else: - usage_metadata = None - - additional_kwargs = {"citations": response.citations} - for attr in ["images", "related_questions"]: - if hasattr(response, attr): - additional_kwargs[attr] = getattr(response, attr) - - message = AIMessage( - content=response.choices[0].message.content, - additional_kwargs=additional_kwargs, - usage_metadata=usage_metadata, - ) - return ChatResult(generations=[ChatGeneration(message=message)]) - - @property - def _invocation_params(self) -> Mapping[str, Any]: - """Get the parameters used to invoke the model.""" - pplx_creds: Dict[str, Any] = { - "model": self.model, - } - return {**pplx_creds, **self._default_params} - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "perplexitychat" - - def with_structured_output( - self, - schema: Optional[_DictOrPydanticClass] = 
None, - *, - method: Literal["json_schema"] = "json_schema", - include_raw: bool = False, - strict: Optional[bool] = None, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, _DictOrPydantic]: - """Model wrapper that returns outputs formatted to match the given schema for Perplexity. - Currently, Perplexity only supports the "json_schema" method for structured output - as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs - - Args: - schema: - The output schema. Can be passed in as: - - - a JSON Schema, - - a TypedDict class, - - or a Pydantic class - - method: The method for steering model generation, currently only supports: - - - "json_schema": Use the JSON Schema to parse the model output - - - include_raw: - If False then only the parsed structured output is returned. If - an error occurs during model output parsing it will be raised. If True - then both the raw model response (a BaseMessage) and the parsed model - response will be returned. If an error occurs during output parsing it - will be caught and returned as well. The final output is always a dict - with keys "raw", "parsed", and "parsing_error". - - kwargs: Additional keyword args aren't supported. - - Returns: - A Runnable that takes the same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. - - | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict. - - | If ``include_raw`` is True, then Runnable outputs a dict with keys: - - - "raw": BaseMessage - - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. - - "parsing_error": Optional[BaseException] - - """ # noqa: E501 - if method in ("function_calling", "json_mode"): - method = "json_schema" - if method == "json_schema": - if schema is None: - raise ValueError( - "schema must be specified when method is 'json_schema'. " - "Received None." - ) - is_pydantic_schema = _is_pydantic_class(schema) - if is_pydantic_schema and hasattr( - schema, "model_json_schema" - ): # accounting for pydantic v1 and v2 - response_format = schema.model_json_schema() - elif is_pydantic_schema: - response_format = schema.schema() # type: ignore[union-attr] - elif isinstance(schema, dict): - response_format = schema - elif type(schema).__name__ == "_TypedDictMeta": - adapter = TypeAdapter(schema) # if the user passes a TypedDict - response_format = adapter.json_schema() - - llm = self.bind( - response_format={ - "type": "json_schema", - "json_schema": {"schema": response_format}, - } - ) - output_parser = ( - PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] - if is_pydantic_schema - else JsonOutputParser() - ) - else: - raise ValueError( - f"Unrecognized method argument. 
Expected 'json_schema' Received:\ - '{method}'" - ) - - if include_raw: - parser_assign = RunnablePassthrough.assign( - parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None - ) - parser_none = RunnablePassthrough.assign(parsed=lambda _: None) - parser_with_fallback = parser_assign.with_fallbacks( - [parser_none], exception_key="parsing_error" - ) - return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser diff --git a/libs/community/langchain_community/chat_models/promptlayer_openai.py b/libs/community/langchain_community/chat_models/promptlayer_openai.py index ee6436267..dec101cfc 100644 --- a/libs/community/langchain_community/chat_models/promptlayer_openai.py +++ b/libs/community/langchain_community/chat_models/promptlayer_openai.py @@ -10,7 +10,7 @@ from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatResult -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI class PromptLayerChatOpenAI(ChatOpenAI): @@ -37,7 +37,7 @@ class PromptLayerChatOpenAI(ChatOpenAI): openai = PromptLayerChatOpenAI(model="gpt-3.5-turbo") """ - pl_tags: Optional[List[str]] + pl_tags: Optional[List[str]] = None return_pl_id: Optional[bool] = False @classmethod diff --git a/libs/community/langchain_community/chat_models/sambanova.py b/libs/community/langchain_community/chat_models/sambanova.py deleted file mode 100644 index 1146a0374..000000000 --- a/libs/community/langchain_community/chat_models/sambanova.py +++ /dev/null @@ -1,2219 +0,0 @@ -import json -from operator import itemgetter -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - Literal, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, -) - -import requests -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import ( - CallbackManagerForLLMRun, -) -from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models import ( - BaseChatModel, - generate_from_stream, -) -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessage, - HumanMessage, - SystemMessage, - ToolMessage, -) -from langchain_core.output_parsers import ( - JsonOutputParser, - PydanticOutputParser, -) -from langchain_core.output_parsers.base import OutputParserLike -from langchain_core.output_parsers.openai_tools import ( - JsonOutputKeyToolsParser, - PydanticToolsParser, - make_invalid_tool_call, - parse_tool_call, -) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough -from langchain_core.tools import BaseTool -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env -from langchain_core.utils.function_calling import convert_to_openai_tool -from langchain_core.utils.pydantic import is_basemodel_subclass -from pydantic import BaseModel, Field, SecretStr -from requests import Response - - -def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]: - """ - convert a BaseMessage to a dictionary with Role / content - - Args: - message: BaseMessage - - Returns: - messages_dict: role / content dict - """ - message_dict: Dict[str, Any] = {} - if isinstance(message, ChatMessage): - message_dict = {"role": message.role, "content": message.content} - elif isinstance(message, SystemMessage): - message_dict = {"role": "system", "content": 
message.content} - elif isinstance(message, HumanMessage): - message_dict = {"role": "user", "content": message.content} - elif isinstance(message, AIMessage): - message_dict = {"role": "assistant", "content": message.content} - if "tool_calls" in message.additional_kwargs: - message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] - if message_dict["content"] == "": - message_dict["content"] = None - elif isinstance(message, ToolMessage): - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id, - } - else: - raise TypeError(f"Got unknown type {message}") - return message_dict - - -def _create_message_dicts(messages: List[BaseMessage]) -> List[Dict[str, Any]]: - """ - Convert a list of BaseMessages to a list of dictionaries with Role / content - - Args: - messages: list of BaseMessages - - Returns: - messages_dicts: list of role / content dicts - """ - message_dicts = [_convert_message_to_dict(m) for m in messages] - return message_dicts - - -def _is_pydantic_class(obj: Any) -> bool: - return isinstance(obj, type) and is_basemodel_subclass(obj) - - -@deprecated( - since="0.3.16", - removal="1.0", - alternative_import="langchain_sambanova.ChatSambaNovaCloud", -) -class ChatSambaNovaCloud(BaseChatModel): - """ - SambaNova Cloud chat model. - - Setup: - To use, you should have the environment variables: - `SAMBANOVA_URL` set with your SambaNova Cloud URL. - `SAMBANOVA_API_KEY` set with your SambaNova Cloud API Key. - http://cloud.sambanova.ai/ - Example: - .. code-block:: python - ChatSambaNovaCloud( - sambanova_url = SambaNova cloud endpoint URL, - sambanova_api_key = set with your SambaNova cloud API key, - model = model name, - max_tokens = max number of tokens to generate, - temperature = model temperature, - top_p = model top p, - top_k = model top k, - stream_options = include usage to get generation metrics - ) - - Key init args — completion params: - model: str - The name of the model to use, e.g., Meta-Llama-3-70B-Instruct. - streaming: bool - Whether to use streaming handler when using non streaming methods - max_tokens: int - max tokens to generate - temperature: float - model temperature - top_p: float - model top p - top_k: int - model top k - stream_options: dict - stream options, include usage to get generation metrics - - Key init args — client params: - sambanova_url: str - SambaNova Cloud Url - sambanova_api_key: str - SambaNova Cloud api key - - Instantiate: - .. code-block:: python - - from langchain_community.chat_models import ChatSambaNovaCloud - - chat = ChatSambaNovaCloud( - sambanova_url = SambaNova cloud endpoint URL, - sambanova_api_key = set with your SambaNova cloud API key, - model = model name, - max_tokens = max number of tokens to generate, - temperature = model temperature, - top_p = model top p, - top_k = model top k, - stream_options = include usage to get generation metrics - ) - - Invoke: - .. code-block:: python - - messages = [ - SystemMessage(content="you are an AI assistant."), - HumanMessage(content="tell me a joke."), - ] - response = chat.invoke(messages) - - Stream: - .. code-block:: python - - for chunk in chat.stream(messages): - print(chunk.content, end="", flush=True) - - Async: - .. code-block:: python - - response = await chat.ainvoke(messages) - - Tool calling: - ..
code-block:: python - - from pydantic import BaseModel, Field - - class GetWeather(BaseModel): - '''Get the current weather in a given location''' - - location: str = Field( - ..., - description="The city and state, e.g. Los Angeles, CA" - ) - - llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) - ai_msg = llm_with_tools.invoke("Should I bring my umbrella today in LA?") - ai_msg.tool_calls - - .. code-block:: none - - [ - { - 'name': 'GetWeather', - 'args': {'location': 'Los Angeles, CA'}, - 'id': 'call_adf61180ea2b4d228a' - } - ] - - Structured output: - .. code-block:: python - - from typing import Optional - - from pydantic import BaseModel, Field - - class Joke(BaseModel): - '''Joke to tell user.''' - - setup: str = Field(description="The setup of the joke") - punchline: str = Field(description="The punchline to the joke") - - structured_model = llm.with_structured_output(Joke) - structured_model.invoke("Tell me a joke about cats") - - .. code-block:: python - - Joke(setup="Why did the cat join a band?", - punchline="Because it wanted to be the purr-cussionist!") - - See `ChatSambaNovaCloud.with_structured_output()` for more. - - Token usage: - .. code-block:: python - - response = chat.invoke(messages) - print(response.response_metadata["usage"]["prompt_tokens"]) - print(response.response_metadata["usage"]["total_tokens"]) - - Response metadata: - .. code-block:: python - - response = chat.invoke(messages) - print(response.response_metadata) - - """ - - sambanova_url: str = Field(default="") - """SambaNova Cloud Url""" - - sambanova_api_key: SecretStr = Field(default=SecretStr("")) - """SambaNova Cloud api key""" - - model: str = Field(default="Meta-Llama-3.1-8B-Instruct") - """The name of the model""" - - streaming: bool = Field(default=False) - """Whether to use streaming handler when using non streaming methods""" - - max_tokens: int = Field(default=1024) - """max tokens to generate""" - - temperature: float = Field(default=0.7) - """model temperature""" - - top_p: Optional[float] = Field(default=None) - """model top p""" - - top_k: Optional[int] = Field(default=None) - """model top k""" - - stream_options: Dict[str, Any] = Field(default={"include_usage": True}) - """stream options, include usage to get generation metrics""" - - additional_headers: Dict[str, Any] = Field(default={}) - """Additional headers to send in the request""" - - class Config: - populate_by_name = True - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this model can be serialized by Langchain.""" - return False - - @property - def lc_secrets(self) -> Dict[str, str]: - return {"sambanova_api_key": "sambanova_api_key"} - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Return a dictionary of identifying parameters. - - This information is used by the LangChain callback system, which - is used for tracing purposes and makes it possible to monitor LLMs.
- """ - return { - "model": self.model, - "streaming": self.streaming, - "max_tokens": self.max_tokens, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "stream_options": self.stream_options, - } - - @property - def _llm_type(self) -> str: - """Get the type of language model used by this chat model.""" - return "sambanovacloud-chatmodel" - - def __init__(self, **kwargs: Any) -> None: - """init and validate environment variables""" - kwargs["sambanova_url"] = get_from_dict_or_env( - kwargs, - "sambanova_url", - "SAMBANOVA_URL", - default="https://api.sambanova.ai/v1/chat/completions", - ) - kwargs["sambanova_api_key"] = convert_to_secret_str( - get_from_dict_or_env(kwargs, "sambanova_api_key", "SAMBANOVA_API_KEY") - ) - super().__init__(**kwargs) - - def bind_tools( - self, - tools: Sequence[Union[Dict[str, Any], Type[Any], Callable[..., Any], BaseTool]], - *, - tool_choice: Optional[Union[Dict[str, Any], bool, str]] = None, - parallel_tool_calls: Optional[bool] = False, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, AIMessage]: - """Bind tool-like objects to this chat model - - tool_choice: does not currently support "any", choice like - should be one of ["auto", "none", "required"] - """ - - formatted_tools = [convert_to_openai_tool(tool) for tool in tools] - - if tool_choice: - if isinstance(tool_choice, str): - # tool_choice is a tool/function name - if tool_choice not in ("auto", "none", "required"): - tool_choice = "auto" - elif isinstance(tool_choice, bool): - if tool_choice: - tool_choice = "required" - elif isinstance(tool_choice, dict): - raise ValueError( - "tool_choice must be one of ['auto', 'none', 'required']" - ) - else: - raise ValueError( - f"Unrecognized tool_choice type. Expected str, bool" - f"Received: {tool_choice}" - ) - else: - tool_choice = "auto" - kwargs["tool_choice"] = tool_choice - kwargs["parallel_tool_calls"] = parallel_tool_calls - return super().bind(tools=formatted_tools, **kwargs) - - def with_structured_output( - self, - schema: Optional[Union[Dict[str, Any], Type[BaseModel]]] = None, - *, - method: Literal[ - "function_calling", "json_mode", "json_schema" - ] = "function_calling", - include_raw: bool = False, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[Dict[str, Any], BaseModel]]: - """Model wrapper that returns outputs formatted to match the given schema. - - Args: - schema: - The output schema. Can be passed in as: - - an OpenAI function/tool schema, - - a JSON Schema, - - a TypedDict class, - - or a Pydantic.BaseModel class. - If `schema` is a Pydantic class then the model output will be a - Pydantic instance of that class, and the model-generated fields will be - validated by the Pydantic class. Otherwise the model output will be a - dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool` - for more on how to properly specify types and descriptions of - schema fields when specifying a Pydantic or TypedDict class. - - method: - The method for steering model generation, either "function_calling" - "json_mode" or "json_schema". - If "function_calling" then the schema will be converted - to an OpenAI function and the returned model will make use of the - function-calling API. If "json_mode" or "json_schema" then OpenAI's - JSON mode will be used. - Note that if using "json_mode" or "json_schema" then you must include instructions - for formatting the output into the desired schema into the model call. 
- - include_raw: - If False then only the parsed structured output is returned. If - an error occurs during model output parsing it will be raised. If True - then both the raw model response (a BaseMessage) and the parsed model - response will be returned. If an error occurs during output parsing it - will be caught and returned as well. The final output is always a dict - with keys "raw", "parsed", and "parsing_error". - - Returns: - A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. - - If `include_raw` is False and `schema` is a Pydantic class, Runnable outputs - an instance of `schema` (i.e., a Pydantic object). - - Otherwise, if `include_raw` is False then Runnable outputs a dict. - - If `include_raw` is True, then Runnable outputs a dict with keys: - - `"raw"`: BaseMessage - - `"parsed"`: None if there was a parsing error, otherwise the type depends on the `schema` as described above. - - `"parsing_error"`: Optional[BaseException] - - Example: schema=Pydantic class, method="function_calling", include_raw=False: - .. code-block:: python - - from typing import Optional - - from langchain_community.chat_models import ChatSambaNovaCloud - from pydantic import BaseModel, Field - - - class AnswerWithJustification(BaseModel): - '''An answer to the user question along with justification for the answer.''' - - answer: str - justification: str = Field( - description="A justification for the answer." - ) - - - llm = ChatSambaNovaCloud(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - - # -> AnswerWithJustification( - # answer='They weigh the same', - # justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same.' - # ) - - Example: schema=Pydantic class, method="function_calling", include_raw=True: - .. 
code-block:: python - - from langchain_community.chat_models import ChatSambaNovaCloud - from pydantic import BaseModel - - - class AnswerWithJustification(BaseModel): - '''An answer to the user question along with justification for the answer.''' - - answer: str - justification: str - - - llm = ChatSambaNovaCloud(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output( - AnswerWithJustification, include_raw=True - ) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - # -> { - # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'arguments': '{"answer": "They weigh the same.", "justification": "A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount."}', 'name': 'AnswerWithJustification'}, 'id': 'call_17a431fc6a4240e1bd', 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls', 'usage': {'acceptance_rate': 5, 'completion_tokens': 53, 'completion_tokens_after_first_per_sec': 343.7964936837758, 'completion_tokens_after_first_per_sec_first_ten': 439.1205661878638, 'completion_tokens_per_sec': 162.8511306784833, 'end_time': 1731527851.0698032, 'is_last_response': True, 'prompt_tokens': 213, 'start_time': 1731527850.7137961, 'time_to_first_token': 0.20475482940673828, 'total_latency': 0.32545061111450196, 'total_tokens': 266, 'total_tokens_per_sec': 817.3283162354066}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731527850}, id='95667eaf-447f-4b53-bb6e-b6e1094ded88', tool_calls=[{'name': 'AnswerWithJustification', 'args': {'answer': 'They weigh the same.', 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'}, 'id': 'call_17a431fc6a4240e1bd', 'type': 'tool_call'}]), - # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'), - # 'parsing_error': None - # } - - Example: schema=TypedDict class, method="function_calling", include_raw=False: - .. code-block:: python - - # IMPORTANT: If you are using Python <=3.8, you need to import Annotated - # from typing_extensions, not from typing. - from typing_extensions import Annotated, TypedDict - - from langchain_community.chat_models import ChatSambaNovaCloud - - - class AnswerWithJustification(TypedDict): - '''An answer to the user question along with justification for the answer.''' - - answer: str - justification: Annotated[ - Optional[str], None, "A justification for the answer." - ] - - - llm = ChatSambaNovaCloud(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - # -> { - # 'answer': 'They weigh the same', - # 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.' - # } - - Example: schema=OpenAI function schema, method="function_calling", include_raw=False: - .. 
code-block:: python - - from langchain_community.chat_models import ChatSambaNovaCloud - - oai_schema = { - 'name': 'AnswerWithJustification', - 'description': 'An answer to the user question along with justification for the answer.', - 'parameters': { - 'type': 'object', - 'properties': { - 'answer': {'type': 'string'}, - 'justification': {'description': 'A justification for the answer.', 'type': 'string'} - }, - 'required': ['answer'] - } - } - - llm = ChatSambaNovaCloud(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(oai_schema) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - # -> { - # 'answer': 'They weigh the same', - # 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.' - # } - - Example: schema=Pydantic class, method="json_mode", include_raw=True: - .. code-block:: - - from langchain_community.chat_models import ChatSambaNovaCloud - from pydantic import BaseModel - - class AnswerWithJustification(BaseModel): - answer: str - justification: str - - llm = ChatSambaNovaCloud(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output( - AnswerWithJustification, - method="json_mode", - include_raw=True - ) - - structured_llm.invoke( - "Answer the following question. " - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" - "What's heavier a pound of bricks or a pound of feathers?" - ) - # -> { - # 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.3125, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 292.65701089829776, 'completion_tokens_after_first_per_sec_first_ten': 346.43324678555325, 'completion_tokens_per_sec': 200.012158915008, 'end_time': 1731528071.1708555, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528070.737394, 'time_to_first_token': 0.16693782806396484, 'total_latency': 0.3949759876026827, 'total_tokens': 149, 'total_tokens_per_sec': 377.2381225105847}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528070}, id='83208297-3eb9-4021-a856-ca78a15758df'), - # 'parsed': AnswerWithJustification(answer='They are the same weight', justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'), - # 'parsing_error': None - # } - - Example: schema=None, method="json_mode", include_raw=True: - .. code-block:: - - from langchain_community.chat_models import ChatSambaNovaCloud - - llm = ChatSambaNovaCloud(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(method="json_mode", include_raw=True) - - structured_llm.invoke( - "Answer the following question. 
" - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" - "What's heavier a pound of bricks or a pound of feathers?" - ) - # -> { - # 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 4.722222222222222, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 357.1315485254867, 'completion_tokens_after_first_per_sec_first_ten': 416.83279609305305, 'completion_tokens_per_sec': 240.92819585198137, 'end_time': 1731528164.8474727, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528164.4906917, 'time_to_first_token': 0.13837409019470215, 'total_latency': 0.3278985247892492, 'total_tokens': 149, 'total_tokens_per_sec': 454.4088757208256}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528164}, id='15261eaf-8a25-42ef-8ed5-f63d8bf5b1b0'), - # 'parsed': { - # 'answer': 'They are the same weight', - # 'justification': 'A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'}, - # }, - # 'parsing_error': None - # } - - Example: schema=None, method="json_schema", include_raw=True: - .. code-block:: - - from langchain_community.chat_models import ChatSambaNovaCloud - - class AnswerWithJustification(BaseModel): - answer: str - justification: str - - llm = ChatSambaNovaCloud(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification, method="json_schema", include_raw=True) - - structured_llm.invoke( - "Answer the following question. " - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" - "What's heavier a pound of bricks or a pound of feathers?" - ) - # -> { - # 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. 
A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.3125, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 292.65701089829776, 'completion_tokens_after_first_per_sec_first_ten': 346.43324678555325, 'completion_tokens_per_sec': 200.012158915008, 'end_time': 1731528071.1708555, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528070.737394, 'time_to_first_token': 0.16693782806396484, 'total_latency': 0.3949759876026827, 'total_tokens': 149, 'total_tokens_per_sec': 377.2381225105847}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528070}, id='83208297-3eb9-4021-a856-ca78a15758df'), - # 'parsed': AnswerWithJustification(answer='They are the same weight', justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'), - # 'parsing_error': None - # } - """ # noqa: E501 - if kwargs: - raise ValueError(f"Received unsupported arguments {kwargs}") - is_pydantic_schema = _is_pydantic_class(schema) - if method == "function_calling": - if schema is None: - raise ValueError( - "`schema` must be specified when method is `function_calling`. " - "Received None." - ) - tool_name = convert_to_openai_tool(schema)["function"]["name"] - llm = self.bind_tools([schema], tool_choice=tool_name) - if is_pydantic_schema: - output_parser: OutputParserLike[Any] = PydanticToolsParser( - tools=[schema], # type: ignore[list-item] - first_tool_only=True, - ) - else: - output_parser = JsonOutputKeyToolsParser( - key_name=tool_name, first_tool_only=True - ) - elif method == "json_mode": - llm = self - # TODO bind response format when json mode available by API - # llm = self.bind(response_format={"type": "json_object"}) - if is_pydantic_schema: - schema = cast(Type[BaseModel], schema) - output_parser = PydanticOutputParser(pydantic_object=schema) - else: - output_parser = JsonOutputParser() - - elif method == "json_schema": - if schema is None: - raise ValueError( - "`schema` must be specified when method is not `json_mode`. " - "Received None." - ) - llm = self - # TODO bind response format when json schema available by API, - # update example - # llm = self.bind( - # response_format={"type": "json_object", "json_schema": schema} - # ) - if is_pydantic_schema: - schema = cast(Type[BaseModel], schema) - output_parser = PydanticOutputParser(pydantic_object=schema) - else: - output_parser = JsonOutputParser() - else: - raise ValueError( - f"Unrecognized method argument. Expected one of `function_calling` or " - f"`json_mode`. Received: `{method}`" - ) - - if include_raw: - parser_assign = RunnablePassthrough.assign( - parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None - ) - parser_none = RunnablePassthrough.assign(parsed=lambda _: None) - parser_with_fallback = parser_assign.with_fallbacks( - [parser_none], exception_key="parsing_error" - ) - return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser - - def _handle_request( - self, - messages_dicts: List[Dict[str, Any]], - stop: Optional[List[str]] = None, - streaming: bool = False, - **kwargs: Any, - ) -> Response: - """ - Performs a post request to the LLM API. 
- - Args: - messages_dicts: List of role / content dicts to use as input. - stop: list of stop tokens - streaming: whether to do a streaming call - - Returns: - A request Response object. - """ - if streaming: - data = { - "messages": messages_dicts, - "max_tokens": self.max_tokens, - "stop": stop, - "model": self.model, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "stream": True, - "stream_options": self.stream_options, - **kwargs, - } - else: - data = { - "messages": messages_dicts, - "max_tokens": self.max_tokens, - "stop": stop, - "model": self.model, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - **kwargs, - } - http_session = requests.Session() - response = http_session.post( - self.sambanova_url, - headers={ - "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}", - "Content-Type": "application/json", - **self.additional_headers, - }, - json=data, - stream=streaming, - ) - if response.status_code != 200: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}.", - f"{response.text}.", - ) - return response - - def _process_response(self, response: Response) -> AIMessage: - """ - Process a non-streaming response from the API. - - Args: - response: A request Response object - - Returns: - generation: an AIMessage with model generation - """ - try: - response_dict = response.json() - if response_dict.get("error"): - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}.", - f"{response_dict}.", - ) - except Exception as e: - raise RuntimeError( - f"Sambanova /complete call failed, couldn't get JSON response: {e}, " - f"response: {response.text}" - ) - content = response_dict["choices"][0]["message"].get("content", "") - if content is None: - content = "" - additional_kwargs: Dict[str, Any] = {} - tool_calls = [] - invalid_tool_calls = [] - raw_tool_calls = response_dict["choices"][0]["message"].get("tool_calls") - if raw_tool_calls: - additional_kwargs["tool_calls"] = raw_tool_calls - for raw_tool_call in raw_tool_calls: - if isinstance(raw_tool_call["function"]["arguments"], dict): - raw_tool_call["function"]["arguments"] = json.dumps( - raw_tool_call["function"].get("arguments", {}) - ) - try: - tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) - except Exception as e: - invalid_tool_calls.append( - make_invalid_tool_call(raw_tool_call, str(e)) - ) - message = AIMessage( - content=content, - additional_kwargs=additional_kwargs, - tool_calls=tool_calls, - invalid_tool_calls=invalid_tool_calls, - response_metadata={ - "finish_reason": response_dict["choices"][0]["finish_reason"], - "usage": response_dict.get("usage"), - "model_name": response_dict["model"], - "system_fingerprint": response_dict["system_fingerprint"], - "created": response_dict["created"], - }, - id=response_dict["id"], - ) - return message - - def _process_stream_response( - self, response: Response - ) -> Iterator[BaseMessageChunk]: - """ - Process a streaming response from the API. - - Args: - response: An iterable request Response object - - Yields: - generation: an AIMessageChunk with model partial generation - """ - try: - import sseclient - except ImportError: - raise ImportError( - "could not import sseclient library. " - "Please install it with `pip install sseclient-py`."
- ) - - client = sseclient.SSEClient(response) - - for event in client.events(): - if event.event == "error_event": - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - - try: - # check if the response is a final event - # in that case event data response is '[DONE]' - if event.data != "[DONE]": - if isinstance(event.data, str): - data = json.loads(event.data) - else: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - if data.get("error"): - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - if len(data["choices"]) > 0: - finish_reason = data["choices"][0].get("finish_reason") - content = data["choices"][0]["delta"]["content"] - id = data["id"] - chunk = AIMessageChunk( - content=content, id=id, additional_kwargs={} - ) - else: - content = "" - id = data["id"] - metadata = { - "finish_reason": finish_reason, - "usage": data.get("usage"), - "model_name": data["model"], - "system_fingerprint": data["system_fingerprint"], - "created": data["created"], - } - chunk = AIMessageChunk( - content=content, - id=id, - response_metadata=metadata, - additional_kwargs={}, - ) - yield chunk - - except Exception as e: - raise RuntimeError( - f"Error getting content chunk raw streamed response: {e}" - f"data: {event.data}" - ) - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """ - Call SambaNovaCloud models. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - run_manager: A run manager with callbacks for the LLM. - - Returns: - result: ChatResult with model generation - """ - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - if stream_iter: - return generate_from_stream(stream_iter) - messages_dicts = _create_message_dicts(messages) - response = self._handle_request(messages_dicts, stop, streaming=False, **kwargs) - message = self._process_response(response) - generation = ChatGeneration( - message=message, - generation_info={ - "finish_reason": message.response_metadata["finish_reason"] - }, - ) - return ChatResult(generations=[generation]) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - """ - Stream the output of the SambaNovaCloud chat model. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. 
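For orientation, here is a minimal usage sketch of the request/stream flow implemented above. It is illustrative only: the endpoint URL below is a placeholder, while the constructor fields (`sambanova_url`, `sambanova_api_key`, `model`, `temperature`) and the `invoke`/`stream` calls are taken from the surrounding code and docstring examples.

.. code-block:: python

    from langchain_community.chat_models import ChatSambaNovaCloud

    llm = ChatSambaNovaCloud(
        sambanova_url="https://example.sambanova.ai/v1/chat/completions",  # placeholder URL
        sambanova_api_key="your-api-key",
        model="Meta-Llama-3.1-70B-Instruct",
        temperature=0,
    )

    # Non-streaming path: _generate -> _handle_request -> _process_response
    print(llm.invoke("Tell me a joke").content)

    # Streaming path: _stream -> _handle_request(streaming=True) -> _process_stream_response
    for chunk in llm.stream("Tell me a joke"):
        print(chunk.content, end="", flush=True)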
- run_manager: A run manager with callbacks for the LLM. - - Yields: - chunk: ChatGenerationChunk with model partial generation - """ - messages_dicts = _create_message_dicts(messages) - response = self._handle_request(messages_dicts, stop, streaming=True, **kwargs) - for ai_message_chunk in self._process_stream_response(response): - chunk = ChatGenerationChunk(message=ai_message_chunk) - if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) - yield chunk - - -@deprecated( - since="0.3.16", - removal="1.0", - alternative_import="langchain_sambanova.ChatSambaStudio", -) -class ChatSambaStudio(BaseChatModel): - """ - SambaStudio chat model. - - Setup: - To use, you should have the environment variables: - `SAMBASTUDIO_URL` set with your SambaStudio deployed endpoint URL. - `SAMBASTUDIO_API_KEY` set with your SambaStudio deployed endpoint key. - https://docs.sambanova.ai/sambastudio/latest/index.html - Example: - - .. code-block:: python - - ChatSambaStudio( - sambastudio_url = set with your SambaStudio deployed endpoint URL, - sambastudio_api_key = set with your SambaStudio deployed endpoint key, - model = model or expert name (set for Bundle endpoints), - max_tokens = max number of tokens to generate, - temperature = model temperature, - top_p = model top p, - top_k = model top k, - do_sample = whether to do sample, - process_prompt = whether to process prompt - (set for Bundle generic v1 and v2 endpoints), - stream_options = include usage to get generation metrics, - special_tokens = start, start_role, end_role, end special tokens - (set for Bundle generic v1 and v2 endpoints when process prompt - set to false or for StandAlone v1 and v2 endpoints), - model_kwargs: Optional = Extra keyword arguments to pass to the model. - ) - - Key init args — completion params: - model: str - The name of the model to use, e.g., Meta-Llama-3-70B-Instruct-4096 - (set for Bundle endpoints). - streaming: bool - Whether to use streaming handler when using non streaming methods - max_tokens: int - max tokens to generate - temperature: float - model temperature - top_p: float - model top p - top_k: int - model top k - do_sample: bool - whether to do sample - process_prompt: - whether to process prompt (set for Bundle generic v1 and v2 endpoints) - stream_options: dict - stream options, include usage to get generation metrics - special_tokens: dict - start, start_role, end_role and end special tokens - (set for Bundle generic v1 and v2 endpoints when process prompt set to false - or for StandAlone v1 and v2 endpoints) default to llama3 special tokens - model_kwargs: dict - Extra keyword arguments to pass to the model. - - Key init args — client params: - sambastudio_url: str - SambaStudio endpoint URL - sambastudio_api_key: str - SambaStudio endpoint API key - - Instantiate: - .. code-block:: python - - from langchain_community.chat_models import ChatSambaStudio - - chat = ChatSambaStudio( - sambastudio_url = set with your SambaStudio deployed endpoint URL, - sambastudio_api_key = set with your SambaStudio deployed endpoint key,
- model = model or expert name (set for Bundle endpoints), - max_tokens = max number of tokens to generate, - temperature = model temperature, - top_p = model top p, - top_k = model top k, - do_sample = whether to do sample, - process_prompt = whether to process prompt - (set for Bundle generic v1 and v2 endpoints), - stream_options = include usage to get generation metrics, - special_tokens = start, start_role, end_role, end special tokens - (set for Bundle generic v1 and v2 endpoints when process prompt - set to false or for StandAlone v1 and v2 endpoints), - model_kwargs: Optional = Extra keyword arguments to pass to the model. - ) - - Invoke: - .. code-block:: python - - messages = [ - SystemMessage(content="you are an AI assistant."), - HumanMessage(content="tell me a joke."), - ] - response = chat.invoke(messages) - - Stream: - .. code-block:: python - - for chunk in chat.stream(messages): - print(chunk.content, end="", flush=True) - - Async: - .. code-block:: python - - response = await chat.ainvoke(messages) - - Tool calling: - .. code-block:: python - - from pydantic import BaseModel, Field - - class GetWeather(BaseModel): - '''Get the current weather in a given location''' - - location: str = Field( - ..., - description="The city and state, e.g. Los Angeles, CA" - ) - - llm_with_tools = chat.bind_tools([GetWeather]) - ai_msg = llm_with_tools.invoke("Should I bring my umbrella today in LA?") - ai_msg.tool_calls - - .. code-block:: python - - [ - { - 'name': 'GetWeather', - 'args': {'location': 'Los Angeles, CA'}, - 'id': 'call_adf61180ea2b4d228a' - } - ] - - Structured output: - .. code-block:: python - - from typing import Optional - - from pydantic import BaseModel, Field - - class Joke(BaseModel): - '''Joke to tell user.''' - - setup: str = Field(description="The setup of the joke") - punchline: str = Field(description="The punchline to the joke") - - structured_model = chat.with_structured_output(Joke) - structured_model.invoke("Tell me a joke about cats") - - .. code-block:: python - - Joke(setup="Why did the cat join a band?", - punchline="Because it wanted to be the purr-cussionist!") - - See `ChatSambaStudio.with_structured_output()` for more. - - Token usage: - .. code-block:: python - - response = chat.invoke(messages) - print(response.response_metadata["usage"]["prompt_tokens"]) - print(response.response_metadata["usage"]["total_tokens"]) - - Response metadata: - ..
code-block:: python - - response = chat.invoke(messages) - print(response.response_metadata) - """ - - sambastudio_url: str = Field(default="") - """SambaStudio Url""" - - sambastudio_api_key: SecretStr = Field(default=SecretStr("")) - """SambaStudio api key""" - - base_url: str = Field(default="", exclude=True) - """SambaStudio non streaming Url""" - - streaming_url: str = Field(default="", exclude=True) - """SambaStudio streaming Url""" - - model: Optional[str] = Field(default=None) - """The name of the model or expert to use (for Bundle endpoints)""" - - streaming: bool = Field(default=False) - """Whether to use streaming handler when using non streaming methods""" - - max_tokens: int = Field(default=1024) - """max tokens to generate""" - - temperature: Optional[float] = Field(default=0.7) - """model temperature""" - - top_p: Optional[float] = Field(default=None) - """model top p""" - - top_k: Optional[int] = Field(default=None) - """model top k""" - - do_sample: Optional[bool] = Field(default=None) - """whether to do sampling""" - - process_prompt: Optional[bool] = Field(default=True) - """whether process prompt (for Bundle generic v1 and v2 endpoints)""" - - stream_options: Dict[str, Any] = Field(default={"include_usage": True}) - """stream options, include usage to get generation metrics""" - - special_tokens: Dict[str, Any] = Field( - default={ - "start": "<|begin_of_text|>", - "start_role": "<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>", - "end_role": "<|eot_id|>", - "end": "<|start_header_id|>assistant<|end_header_id|>\n", - } - ) - """start, start_role, end_role and end special tokens - (set for Bundle generic v1 and v2 endpoints when process prompt set to false - or for StandAlone v1 and v2 endpoints) - default to llama3 special tokens""" - - model_kwargs: Optional[Dict[str, Any]] = None - """Key word arguments to pass to the model.""" - - additional_headers: Dict[str, Any] = Field(default={}) - """Additional headers to send in request""" - - class Config: - populate_by_name = True - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this model can be serialized by Langchain.""" - return False - - @property - def lc_secrets(self) -> Dict[str, str]: - return { - "sambastudio_url": "sambastudio_url", - "sambastudio_api_key": "sambastudio_api_key", - } - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Return a dictionary of identifying parameters. - - This information is used by the LangChain callback system, which - is used for tracing purposes make it possible to monitor LLMs. 
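As a minimal sketch, the class can be constructed from the fields declared just above; the expert name below is only an example, and the URL and key may also be picked up from the `SAMBASTUDIO_URL` / `SAMBASTUDIO_API_KEY` environment variables handled in `__init__` further below.

.. code-block:: python

    import os

    from langchain_community.chat_models import ChatSambaStudio

    chat = ChatSambaStudio(
        sambastudio_url=os.environ["SAMBASTUDIO_URL"],
        sambastudio_api_key=os.environ["SAMBASTUDIO_API_KEY"],
        model="Meta-Llama-3-70B-Instruct-4096",  # expert name, used by Bundle endpoints
        max_tokens=1024,
        temperature=0.7,
        process_prompt=True,
    )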
- """ - return { - "model": self.model, - "streaming": self.streaming, - "max_tokens": self.max_tokens, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "do_sample": self.do_sample, - "process_prompt": self.process_prompt, - "stream_options": self.stream_options, - "special_tokens": self.special_tokens, - "model_kwargs": self.model_kwargs, - } - - @property - def _llm_type(self) -> str: - """Get the type of language model used by this chat model.""" - return "sambastudio-chatmodel" - - def __init__(self, **kwargs: Any) -> None: - """init and validate environment variables""" - kwargs["sambastudio_url"] = get_from_dict_or_env( - kwargs, "sambastudio_url", "SAMBASTUDIO_URL" - ) - - kwargs["sambastudio_api_key"] = convert_to_secret_str( - get_from_dict_or_env(kwargs, "sambastudio_api_key", "SAMBASTUDIO_API_KEY") - ) - kwargs["base_url"], kwargs["streaming_url"] = self._get_sambastudio_urls( - kwargs["sambastudio_url"] - ) - super().__init__(**kwargs) - - def bind_tools( - self, - tools: Sequence[Union[Dict[str, Any], Type[Any], Callable[..., Any], BaseTool]], - *, - tool_choice: Optional[Union[Dict[str, Any], bool, str]] = None, - parallel_tool_calls: Optional[bool] = False, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, AIMessage]: - """Bind tool-like objects to this chat model - - tool_choice: does not currently support "any", choice like - should be one of ["auto", "none", "required"] - """ - - formatted_tools = [convert_to_openai_tool(tool) for tool in tools] - - if tool_choice: - if isinstance(tool_choice, str): - # tool_choice is a tool/function name - if tool_choice not in ("auto", "none", "required"): - tool_choice = "auto" - elif isinstance(tool_choice, bool): - if tool_choice: - tool_choice = "required" - elif isinstance(tool_choice, dict): - raise ValueError( - "tool_choice must be one of ['auto', 'none', 'required']" - ) - else: - raise ValueError( - f"Unrecognized tool_choice type. Expected str, bool" - f"Received: {tool_choice}" - ) - else: - tool_choice = "auto" - kwargs["tool_choice"] = tool_choice - kwargs["parallel_tool_calls"] = parallel_tool_calls - return super().bind(tools=formatted_tools, **kwargs) - - def with_structured_output( - self, - schema: Optional[Union[Dict[str, Any], Type[BaseModel]]] = None, - *, - method: Literal[ - "function_calling", "json_mode", "json_schema" - ] = "function_calling", - include_raw: bool = False, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[Dict[str, Any], BaseModel]]: - """Model wrapper that returns outputs formatted to match the given schema. - - Args: - schema: - The output schema. Can be passed in as: - - an OpenAI function/tool schema, - - a JSON Schema, - - a TypedDict class, - - or a Pydantic class. - If `schema` is a Pydantic class then the model output will be a - Pydantic instance of that class, and the model-generated fields will be - validated by the Pydantic class. Otherwise the model output will be a - dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool` - for more on how to properly specify types and descriptions of - schema fields when specifying a Pydantic or TypedDict class. - - method: - The method for steering model generation, either "function_calling" - "json_mode" or "json_schema". - If "function_calling" then the schema will be converted - to an OpenAI function and the returned model will make use of the - function-calling API. If "json_mode" or "json_schema" then OpenAI's - JSON mode will be used. 
- Note that if using "json_mode" or "json_schema" then you must include instructions - for formatting the output into the desired schema into the model call. - - include_raw: - If False then only the parsed structured output is returned. If - an error occurs during model output parsing it will be raised. If True - then both the raw model response (a BaseMessage) and the parsed model - response will be returned. If an error occurs during output parsing it - will be caught and returned as well. The final output is always a dict - with keys "raw", "parsed", and "parsing_error". - - Returns: - A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. - - If `include_raw` is False and `schema` is a Pydantic class, Runnable outputs - an instance of `schema` (i.e., a Pydantic object). - - Otherwise, if `include_raw` is False then Runnable outputs a dict. - - If `include_raw` is True, then Runnable outputs a dict with keys: - - `"raw"`: BaseMessage - - `"parsed"`: None if there was a parsing error, otherwise the type depends on the `schema` as described above. - - `"parsing_error"`: Optional[BaseException] - - Example: schema=Pydantic class, method="function_calling", include_raw=False: - .. code-block:: python - - from typing import Optional - - from langchain_community.chat_models import ChatSambaStudio - from pydantic import BaseModel, Field - - - class AnswerWithJustification(BaseModel): - '''An answer to the user question along with justification for the answer.''' - - answer: str - justification: str = Field( - description="A justification for the answer." - ) - - - llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - - # -> AnswerWithJustification( - # answer='They weigh the same', - # justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same.' - # ) - - Example: schema=Pydantic class, method="function_calling", include_raw=True: - .. 
code-block:: python - - from langchain_community.chat_models import ChatSambaStudio - from pydantic import BaseModel - - - class AnswerWithJustification(BaseModel): - '''An answer to the user question along with justification for the answer.''' - - answer: str - justification: str - - - llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output( - AnswerWithJustification, include_raw=True - ) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - # -> { - # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'arguments': '{"answer": "They weigh the same.", "justification": "A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount."}', 'name': 'AnswerWithJustification'}, 'id': 'call_17a431fc6a4240e1bd', 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls', 'usage': {'acceptance_rate': 5, 'completion_tokens': 53, 'completion_tokens_after_first_per_sec': 343.7964936837758, 'completion_tokens_after_first_per_sec_first_ten': 439.1205661878638, 'completion_tokens_per_sec': 162.8511306784833, 'end_time': 1731527851.0698032, 'is_last_response': True, 'prompt_tokens': 213, 'start_time': 1731527850.7137961, 'time_to_first_token': 0.20475482940673828, 'total_latency': 0.32545061111450196, 'total_tokens': 266, 'total_tokens_per_sec': 817.3283162354066}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731527850}, id='95667eaf-447f-4b53-bb6e-b6e1094ded88', tool_calls=[{'name': 'AnswerWithJustification', 'args': {'answer': 'They weigh the same.', 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'}, 'id': 'call_17a431fc6a4240e1bd', 'type': 'tool_call'}]), - # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'), - # 'parsing_error': None - # } - - Example: schema=TypedDict class, method="function_calling", include_raw=False: - .. code-block:: python - - # IMPORTANT: If you are using Python <=3.8, you need to import Annotated - # from typing_extensions, not from typing. - from typing_extensions import Annotated, TypedDict - - from langchain_community.chat_models import ChatSambaStudio - - - class AnswerWithJustification(TypedDict): - '''An answer to the user question along with justification for the answer.''' - - answer: str - justification: Annotated[ - Optional[str], None, "A justification for the answer." - ] - - - llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - # -> { - # 'answer': 'They weigh the same', - # 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.' - # } - - Example: schema=OpenAI function schema, method="function_calling", include_raw=False: - .. 
code-block:: python - - from langchain_community.chat_models import ChatSambaStudio - - oai_schema = { - 'name': 'AnswerWithJustification', - 'description': 'An answer to the user question along with justification for the answer.', - 'parameters': { - 'type': 'object', - 'properties': { - 'answer': {'type': 'string'}, - 'justification': {'description': 'A justification for the answer.', 'type': 'string'} - }, - 'required': ['answer'] - } - } - - llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(oai_schema) - - structured_llm.invoke( - "What weighs more a pound of bricks or a pound of feathers" - ) - # -> { - # 'answer': 'They weigh the same', - # 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.' - # } - - Example: schema=Pydantic class, method="json_mode", include_raw=True: - .. code-block:: - - from langchain_community.chat_models import ChatSambaStudio - from pydantic import BaseModel - - class AnswerWithJustification(BaseModel): - answer: str - justification: str - - llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output( - AnswerWithJustification, - method="json_mode", - include_raw=True - ) - - structured_llm.invoke( - "Answer the following question. " - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" - "What's heavier a pound of bricks or a pound of feathers?" - ) - # -> { - # 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.3125, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 292.65701089829776, 'completion_tokens_after_first_per_sec_first_ten': 346.43324678555325, 'completion_tokens_per_sec': 200.012158915008, 'end_time': 1731528071.1708555, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528070.737394, 'time_to_first_token': 0.16693782806396484, 'total_latency': 0.3949759876026827, 'total_tokens': 149, 'total_tokens_per_sec': 377.2381225105847}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528070}, id='83208297-3eb9-4021-a856-ca78a15758df'), - # 'parsed': AnswerWithJustification(answer='They are the same weight', justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'), - # 'parsing_error': None - # } - - Example: schema=None, method="json_mode", include_raw=True: - .. code-block:: - - from langchain_community.chat_models import ChatSambaStudio - - llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(method="json_mode", include_raw=True) - - structured_llm.invoke( - "Answer the following question. " - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" - "What's heavier a pound of bricks or a pound of feathers?" 
- ) - # -> { - # 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 4.722222222222222, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 357.1315485254867, 'completion_tokens_after_first_per_sec_first_ten': 416.83279609305305, 'completion_tokens_per_sec': 240.92819585198137, 'end_time': 1731528164.8474727, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528164.4906917, 'time_to_first_token': 0.13837409019470215, 'total_latency': 0.3278985247892492, 'total_tokens': 149, 'total_tokens_per_sec': 454.4088757208256}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528164}, id='15261eaf-8a25-42ef-8ed5-f63d8bf5b1b0'), - # 'parsed': { - # 'answer': 'They are the same weight', - # 'justification': 'A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'}, - # }, - # 'parsing_error': None - # } - - Example: schema=None, method="json_schema", include_raw=True: - .. code-block:: - - from langchain_community.chat_models import ChatSambaStudio - - class AnswerWithJustification(BaseModel): - answer: str - justification: str - - llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification, method="json_schema", include_raw=True) - - structured_llm.invoke( - "Answer the following question. " - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" - "What's heavier a pound of bricks or a pound of feathers?" - ) - # -> { - # 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.3125, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 292.65701089829776, 'completion_tokens_after_first_per_sec_first_ten': 346.43324678555325, 'completion_tokens_per_sec': 200.012158915008, 'end_time': 1731528071.1708555, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528070.737394, 'time_to_first_token': 0.16693782806396484, 'total_latency': 0.3949759876026827, 'total_tokens': 149, 'total_tokens_per_sec': 377.2381225105847}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528070}, id='83208297-3eb9-4021-a856-ca78a15758df'), - # 'parsed': AnswerWithJustification(answer='They are the same weight', justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. 
A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'), - # 'parsing_error': None - # } - - """ # noqa: E501 - if kwargs: - raise ValueError(f"Received unsupported arguments {kwargs}") - is_pydantic_schema = _is_pydantic_class(schema) - if method == "function_calling": - if schema is None: - raise ValueError( - "schema must be specified when method is 'function_calling'. " - "Received None." - ) - tool_name = convert_to_openai_tool(schema)["function"]["name"] - llm = self.bind_tools([schema], tool_choice=tool_name) - if is_pydantic_schema: - output_parser: OutputParserLike[Any] = PydanticToolsParser( - tools=[schema], # type: ignore[list-item] - first_tool_only=True, - ) - else: - output_parser = JsonOutputKeyToolsParser( - key_name=tool_name, first_tool_only=True - ) - elif method == "json_mode": - llm = self - # TODO bind response format when json mode available by API - # llm = self.bind(response_format={"type": "json_object"}) - if is_pydantic_schema: - schema = cast(Type[BaseModel], schema) - output_parser = PydanticOutputParser(pydantic_object=schema) - else: - output_parser = JsonOutputParser() - - elif method == "json_schema": - if schema is None: - raise ValueError( - "schema must be specified when method is not 'json_mode'. " - "Received None." - ) - llm = self - # TODO bind response format when json schema available by API, - # update example - # llm = self.bind( - # response_format={"type": "json_object", "json_schema": schema} - # ) - if is_pydantic_schema: - schema = cast(Type[BaseModel], schema) - output_parser = PydanticOutputParser(pydantic_object=schema) - else: - output_parser = JsonOutputParser() - else: - raise ValueError( - f"Unrecognized method argument. Expected one of 'function_calling' or " - f"'json_mode'. 
Received: '{method}'" - ) - - if include_raw: - parser_assign = RunnablePassthrough.assign( - parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None - ) - parser_none = RunnablePassthrough.assign(parsed=lambda _: None) - parser_with_fallback = parser_assign.with_fallbacks( - [parser_none], exception_key="parsing_error" - ) - return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser - - def _get_role(self, message: BaseMessage) -> str: - """ - Get the role of LangChain BaseMessage - - Args: - message: LangChain BaseMessage - - Returns: - str: Role of the LangChain BaseMessage - """ - if isinstance(message, SystemMessage): - role = "system" - elif isinstance(message, HumanMessage): - role = "user" - elif isinstance(message, AIMessage): - role = "assistant" - elif isinstance(message, ToolMessage): - role = "tool" - elif isinstance(message, ChatMessage): - role = message.role - else: - raise TypeError(f"Got unknown type {message}") - return role - - def _messages_to_string(self, messages: List[BaseMessage], **kwargs: Any) -> str: - """ - Convert a list of BaseMessages to a: - - dumped json string with Role / content dict structure - when process_prompt is true, - - string with special tokens if process_prompt is false - for generic V1 and V2 endpoints - - Args: - messages: list of BaseMessages - - Returns: - str: string to send as model input depending on process_prompt param - """ - if self.process_prompt: - messages_dict: Dict[str, Any] = { - "conversation_id": "sambaverse-conversation-id", - "messages": [], - **kwargs, - } - for message in messages: - if isinstance(message, AIMessage): - message_dict = { - "message_id": message.id, - "role": self._get_role(message), - "content": message.content, - } - if "tool_calls" in message.additional_kwargs: - message_dict["tool_calls"] = message.additional_kwargs[ - "tool_calls" - ] - if message_dict["content"] == "": - message_dict["content"] = None - - elif isinstance(message, ToolMessage): - message_dict = { - "message_id": message.id, - "role": self._get_role(message), - "content": message.content, - "tool_call_id": message.tool_call_id, - } - - else: - message_dict = { - "message_id": message.id, - "role": self._get_role(message), - "content": message.content, - } - - messages_dict["messages"].append(message_dict) - - messages_string = json.dumps(messages_dict) - - else: - if "tools" in kwargs.keys(): - raise NotImplementedError( - "tool calling not supported in API Generic V2 " - "without process_prompt, switch to OpenAI compatible API " - "or Generic V2 API with process_prompt=True" - ) - messages_string = self.special_tokens["start"] - for message in messages: - messages_string += self.special_tokens["start_role"].format( - role=self._get_role(message) - ) - messages_string += f" {message.content} " - messages_string += self.special_tokens["end_role"] - messages_string += self.special_tokens["end"] - - return messages_string - - def _get_sambastudio_urls(self, url: str) -> Tuple[str, str]: - """ - Get streaming and non streaming URLs from the given URL - - Args: - url: string with sambastudio base or streaming endpoint url - - Returns: - base_url: string with url to do non streaming calls - streaming_url: string with url to do streaming calls - """ - if "chat/completions" in url: - base_url = url - stream_url = url - else: - if "stream" in url: - base_url = url.replace("stream/", "") - stream_url = url - else: - base_url = url - if "generic" in url: - stream_url = 
"generic/stream".join(url.split("generic")) - else: - raise ValueError("Unsupported URL") - return base_url, stream_url - - def _handle_request( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - streaming: Optional[bool] = False, - **kwargs: Any, - ) -> Response: - """ - Performs a post request to the LLM API. - - Args: - messages_dicts: List of role / content dicts to use as input. - stop: list of stop tokens - streaming: wether to do a streaming call - - Returns: - A request Response object - """ - - # create request payload for openai compatible API - if "chat/completions" in self.sambastudio_url: - messages_dicts = _create_message_dicts(messages) - data = { - "messages": messages_dicts, - "max_tokens": self.max_tokens, - "stop": stop, - "model": self.model, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "stream": streaming, - "stream_options": self.stream_options, - **kwargs, - } - data = {key: value for key, value in data.items() if value is not None} - headers = { - "Authorization": f"Bearer " - f"{self.sambastudio_api_key.get_secret_value()}", - "Content-Type": "application/json", - **self.additional_headers, - } - - # create request payload for generic v2 API - elif "api/v2/predict/generic" in self.sambastudio_url: - items = [ - {"id": "item0", "value": self._messages_to_string(messages, **kwargs)} - ] - params: Dict[str, Any] = { - "select_expert": self.model, - "process_prompt": self.process_prompt, - "max_tokens_to_generate": self.max_tokens, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "do_sample": self.do_sample, - } - if self.model_kwargs is not None: - params = {**params, **self.model_kwargs} - params = {key: value for key, value in params.items() if value is not None} - data = {"items": items, "params": params} - headers = { - "key": self.sambastudio_api_key.get_secret_value(), - **self.additional_headers, - } - - # create request payload for generic v1 API - elif "api/predict/generic" in self.sambastudio_url: - if "tools" in kwargs.keys(): - raise NotImplementedError( - "tool calling not supported in API Generic V1, " - "switch to OpenAI compatible API or Generic V2 API" - ) - params = { - "select_expert": self.model, - "process_prompt": self.process_prompt, - "max_tokens_to_generate": self.max_tokens, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "do_sample": self.do_sample, - **kwargs, - } - if self.model_kwargs is not None: - params = {**params, **self.model_kwargs} - params = { - key: {"type": type(value).__name__, "value": str(value)} - for key, value in params.items() - if value is not None - } - if streaming: - data = { - "instance": self._messages_to_string(messages), - "params": params, - } - else: - data = { - "instances": [self._messages_to_string(messages)], - "params": params, - } - headers = { - "key": self.sambastudio_api_key.get_secret_value(), - **self.additional_headers, - } - - else: - raise ValueError( - f"Unsupported URL{self.sambastudio_url}" - "only openai, generic v1 and generic v2 APIs are supported" - ) - - http_session = requests.Session() - if streaming: - response = http_session.post( - self.streaming_url, headers=headers, json=data, stream=True - ) - else: - response = http_session.post( - self.base_url, headers=headers, json=data, stream=False - ) - if response.status_code != 200: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{response.text}." 
- ) - return response - - def _process_response(self, response: Response) -> AIMessage: - """ - Process a non streaming response from the api - - Args: - response: A request Response object - - Returns - generation: an AIMessage with model generation - """ - - # Extract json payload form response - try: - response_dict = response.json() - except Exception as e: - raise RuntimeError( - f"Sambanova /complete call failed couldn't get JSON response {e}" - f"response: {response.text}" - ) - - additional_kwargs: Dict[str, Any] = {} - tool_calls = [] - invalid_tool_calls = [] - - # process response payload for openai compatible API - if "chat/completions" in self.sambastudio_url: - content = response_dict["choices"][0]["message"].get("content", "") - if content is None: - content = "" - id = response_dict["id"] - response_metadata = { - "finish_reason": response_dict["choices"][0]["finish_reason"], - "usage": response_dict.get("usage"), - "model_name": response_dict["model"], - "system_fingerprint": response_dict["system_fingerprint"], - "created": response_dict["created"], - } - raw_tool_calls = response_dict["choices"][0]["message"].get("tool_calls") - if raw_tool_calls: - additional_kwargs["tool_calls"] = raw_tool_calls - for raw_tool_call in raw_tool_calls: - if isinstance(raw_tool_call["function"]["arguments"], dict): - raw_tool_call["function"]["arguments"] = json.dumps( - raw_tool_call["function"].get("arguments", {}) - ) - try: - tool_calls.append( - parse_tool_call(raw_tool_call, return_id=True) - ) - except Exception as e: - invalid_tool_calls.append( - make_invalid_tool_call(raw_tool_call, str(e)) - ) - - # process response payload for generic v2 API - elif "api/v2/predict/generic" in self.sambastudio_url: - content = response_dict["items"][0]["value"]["completion"] - id = response_dict["items"][0]["id"] - response_metadata = response_dict["items"][0] - raw_tool_calls = response_dict["items"][0]["value"].get("tool_calls") - if raw_tool_calls: - additional_kwargs["tool_calls"] = raw_tool_calls - for raw_tool_call in raw_tool_calls: - if isinstance(raw_tool_call["function"]["arguments"], dict): - raw_tool_call["function"]["arguments"] = json.dumps( - raw_tool_call["function"].get("arguments", {}) - ) - try: - tool_calls.append( - parse_tool_call(raw_tool_call, return_id=True) - ) - except Exception as e: - invalid_tool_calls.append( - make_invalid_tool_call(raw_tool_call, str(e)) - ) - - # process response payload for generic v1 API - elif "api/predict/generic" in self.sambastudio_url: - content = response_dict["predictions"][0]["completion"] - id = None - response_metadata = response_dict - - else: - raise ValueError( - f"Unsupported URL{self.sambastudio_url}" - "only openai, generic v1 and generic v2 APIs are supported" - ) - - return AIMessage( - content=content, - additional_kwargs=additional_kwargs, - tool_calls=tool_calls, - invalid_tool_calls=invalid_tool_calls, - response_metadata=response_metadata, - id=id, - ) - - def _process_stream_response( - self, response: Response - ) -> Iterator[BaseMessageChunk]: - """ - Process a streaming response from the api - - Args: - response: An iterable request Response object - - Yields: - generation: an AIMessageChunk with model partial generation - """ - - try: - import sseclient - except ImportError: - raise ImportError( - "could not import sseclient library" - "Please install it with `pip install sseclient-py`." 
- ) - - # process response payload for openai compatible API - if "chat/completions" in self.sambastudio_url: - finish_reason = "" - client = sseclient.SSEClient(response) - for event in client.events(): - if event.event == "error_event": - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - try: - # check if the response is not a final event ("[DONE]") - if event.data != "[DONE]": - if isinstance(event.data, str): - data = json.loads(event.data) - else: - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - if data.get("error"): - raise RuntimeError( - f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{event.data}." - ) - if len(data["choices"]) > 0: - finish_reason = data["choices"][0].get("finish_reason") - content = data["choices"][0]["delta"]["content"] - id = data["id"] - metadata = {} - else: - content = "" - id = data["id"] - metadata = { - "finish_reason": finish_reason, - "usage": data.get("usage"), - "model_name": data["model"], - "system_fingerprint": data["system_fingerprint"], - "created": data["created"], - } - if data.get("usage") is not None: - content = "" - id = data["id"] - metadata = { - "finish_reason": finish_reason, - "usage": data.get("usage"), - "model_name": data["model"], - "system_fingerprint": data["system_fingerprint"], - "created": data["created"], - } - yield AIMessageChunk( - content=content, - id=id, - response_metadata=metadata, - additional_kwargs={}, - ) - - except Exception as e: - raise RuntimeError( - f"Error getting content chunk raw streamed response: {e}" - f"data: {event.data}" - ) - - # process response payload for generic v2 API - elif "api/v2/predict/generic" in self.sambastudio_url: - for line in response.iter_lines(): - try: - data = json.loads(line) - content = data["result"]["items"][0]["value"]["stream_token"] - id = data["result"]["items"][0]["id"] - if data["result"]["items"][0]["value"]["is_last_response"]: - metadata = { - "finish_reason": data["result"]["items"][0]["value"].get( - "stop_reason" - ), - "prompt": data["result"]["items"][0]["value"].get("prompt"), - "usage": { - "prompt_tokens_count": data["result"]["items"][0][ - "value" - ].get("prompt_tokens_count"), - "completion_tokens_count": data["result"]["items"][0][ - "value" - ].get("completion_tokens_count"), - "total_tokens_count": data["result"]["items"][0][ - "value" - ].get("total_tokens_count"), - "start_time": data["result"]["items"][0]["value"].get( - "start_time" - ), - "end_time": data["result"]["items"][0]["value"].get( - "end_time" - ), - "model_execution_time": data["result"]["items"][0][ - "value" - ].get("model_execution_time"), - "time_to_first_token": data["result"]["items"][0][ - "value" - ].get("time_to_first_token"), - "throughput_after_first_token": data["result"]["items"][ - 0 - ]["value"].get("throughput_after_first_token"), - "batch_size_used": data["result"]["items"][0][ - "value" - ].get("batch_size_used"), - }, - } - else: - metadata = {} - yield AIMessageChunk( - content=content, - id=id, - response_metadata=metadata, - additional_kwargs={}, - ) - - except Exception as e: - raise RuntimeError( - f"Error getting content chunk raw streamed response: {e}" - f"line: {line}" - ) - - # process response payload for generic v1 API - elif "api/predict/generic" in self.sambastudio_url: - for line in response.iter_lines(): - try: - data = json.loads(line) - content = 
data["result"]["responses"][0]["stream_token"] - id = None - if data["result"]["responses"][0]["is_last_response"]: - metadata = { - "finish_reason": data["result"]["responses"][0].get( - "stop_reason" - ), - "prompt": data["result"]["responses"][0].get("prompt"), - "usage": { - "prompt_tokens_count": data["result"]["responses"][ - 0 - ].get("prompt_tokens_count"), - "completion_tokens_count": data["result"]["responses"][ - 0 - ].get("completion_tokens_count"), - "total_tokens_count": data["result"]["responses"][ - 0 - ].get("total_tokens_count"), - "start_time": data["result"]["responses"][0].get( - "start_time" - ), - "end_time": data["result"]["responses"][0].get( - "end_time" - ), - "model_execution_time": data["result"]["responses"][ - 0 - ].get("model_execution_time"), - "time_to_first_token": data["result"]["responses"][ - 0 - ].get("time_to_first_token"), - "throughput_after_first_token": data["result"][ - "responses" - ][0].get("throughput_after_first_token"), - "batch_size_used": data["result"]["responses"][0].get( - "batch_size_used" - ), - }, - } - else: - metadata = {} - yield AIMessageChunk( - content=content, - id=id, - response_metadata=metadata, - additional_kwargs={}, - ) - - except Exception as e: - raise RuntimeError( - f"Error getting content chunk raw streamed response: {e}" - f"line: {line}" - ) - - else: - raise ValueError( - f"Unsupported URL{self.sambastudio_url}" - "only openai, generic v1 and generic v2 APIs are supported" - ) - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """ - Call SambaStudio models. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - run_manager: A run manager with callbacks for the LLM. - - Returns: - result: ChatResult with model generation - """ - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - if stream_iter: - return generate_from_stream(stream_iter) - response = self._handle_request(messages, stop, streaming=False, **kwargs) - message = self._process_response(response) - generation = ChatGeneration(message=message) - return ChatResult(generations=[generation]) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - """ - Stream the output of the SambaStudio model. - - Args: - messages: the prompt composed of a list of messages. - stop: a list of strings on which the model should stop generating. - If generation stops due to a stop token, the stop token itself - SHOULD BE INCLUDED as part of the output. This is not enforced - across models right now, but it's a good practice to follow since - it makes it much easier to parse the output of the model - downstream and understand why generation stopped. - run_manager: A run manager with callbacks for the LLM. 
- - Yields: - chunk: ChatGenerationChunk with model partial generation - """ - response = self._handle_request(messages, stop, streaming=True, **kwargs) - for ai_message_chunk in self._process_stream_response(response): - chunk = ChatGenerationChunk(message=ai_message_chunk) - if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) - yield chunk diff --git a/libs/community/langchain_community/chat_models/solar.py b/libs/community/langchain_community/chat_models/solar.py deleted file mode 100644 index 2be70ddc1..000000000 --- a/libs/community/langchain_community/chat_models/solar.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Wrapper around Solar chat models.""" - -from typing import Dict - -from langchain_core._api import deprecated -from langchain_core.utils import get_from_dict_or_env, pre_init -from pydantic import ConfigDict, Field - -from langchain_community.chat_models import ChatOpenAI -from langchain_community.llms.solar import SOLAR_SERVICE_URL_BASE, SolarCommon - - -@deprecated( - since="0.0.34", removal="1.0", alternative_import="langchain_upstage.ChatUpstage" -) -class SolarChat(SolarCommon, ChatOpenAI): - """Wrapper around Solar large language models. - To use, you should have the ``openai`` python package installed, and the - environment variable ``SOLAR_API_KEY`` set with your API key. - (Solar's chat API is compatible with OpenAI's SDK.) - Referenced from https://console.upstage.ai/services/solar - Example: - .. code-block:: python - - from langchain_community.chat_models.solar import SolarChat - - solar = SolarChat(model="solar-mini") - """ - - max_tokens: int = Field(default=1024) - - # this is needed to match ChatOpenAI superclass - model_config = ConfigDict( - populate_by_name=True, - arbitrary_types_allowed=True, - extra="ignore", - ) - - @pre_init - def validate_environment(cls, values: Dict) -> Dict: - """Validate that the environment is set up correctly.""" - values["solar_api_key"] = get_from_dict_or_env( - values, "solar_api_key", "SOLAR_API_KEY" - ) - - try: - import openai - - except ImportError: - raise ImportError( - "Could not import openai python package. " - "Please install it with `pip install openai`." 
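For context on the `SolarChat` wrapper whose removal begins above, a minimal hypothetical usage sketch based on its own docstring example; it drives the OpenAI SDK against the Solar service URL and expects `SOLAR_API_KEY` in the environment (the class is deprecated in favor of `langchain_upstage.ChatUpstage`).

.. code-block:: python

    from langchain_community.chat_models.solar import SolarChat

    solar = SolarChat(model="solar-mini")  # reads SOLAR_API_KEY from the environment
    print(solar.invoke("Say hello").content)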
- ) - - client_params = { - "api_key": values["solar_api_key"], - "base_url": ( - values["base_url"] if "base_url" in values else SOLAR_SERVICE_URL_BASE - ), - } - - if not values.get("client"): - values["client"] = openai.OpenAI(**client_params).chat.completions - if not values.get("async_client"): - values["async_client"] = openai.AsyncOpenAI( - **client_params - ).chat.completions - - return values diff --git a/libs/community/langchain_community/chat_models/vertexai.py b/libs/community/langchain_community/chat_models/vertexai.py deleted file mode 100644 index eeb638e2d..000000000 --- a/libs/community/langchain_community/chat_models/vertexai.py +++ /dev/null @@ -1,393 +0,0 @@ -"""Wrapper around Google VertexAI chat-based models.""" - -from __future__ import annotations - -import base64 -import logging -import re -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast -from urllib.parse import urlparse - -import requests -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.chat_models import ( - BaseChatModel, - generate_from_stream, -) -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - HumanMessage, - SystemMessage, -) -from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.utils import pre_init - -from langchain_community.llms.vertexai import ( - _VertexAICommon, - is_codey_model, - is_gemini_model, -) -from langchain_community.utilities.vertexai import ( - load_image_from_gcs, - raise_vertex_import_error, -) - -if TYPE_CHECKING: - from vertexai.language_models import ( - ChatMessage, - ChatSession, - CodeChatSession, - InputOutputTextPair, - ) - from vertexai.preview.generative_models import Content - -logger = logging.getLogger(__name__) - - -@dataclass -class _ChatHistory: - """Represents a context and a history of messages.""" - - history: List["ChatMessage"] = field(default_factory=list) - context: Optional[str] = None - - -def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory: - """Parse a sequence of messages into history. - - Args: - history: The list of messages to re-create the history of the chat. - Returns: - A parsed chat history. - Raises: - ValueError: If a sequence of message has a SystemMessage not at the - first place. - """ - from vertexai.language_models import ChatMessage - - vertex_messages, context = [], None - for i, message in enumerate(history): - content = cast(str, message.content) - if i == 0 and isinstance(message, SystemMessage): - context = content - elif isinstance(message, AIMessage): - vertex_message = ChatMessage(content=message.content, author="bot") - vertex_messages.append(vertex_message) - elif isinstance(message, HumanMessage): - vertex_message = ChatMessage(content=message.content, author="user") - vertex_messages.append(vertex_message) - else: - raise ValueError( - f"Unexpected message with type {type(message)} at the position {i}." 
- ) - chat_history = _ChatHistory(context=context, history=vertex_messages) - return chat_history - - -def _is_url(s: str) -> bool: - try: - result = urlparse(s) - return all([result.scheme, result.netloc]) - except Exception as e: - logger.debug(f"Unable to parse URL: {e}") - return False - - -def _parse_chat_history_gemini( - history: List[BaseMessage], project: Optional[str] -) -> List["Content"]: - from vertexai.preview.generative_models import Content, Image, Part - - def _convert_to_prompt(part: Union[str, Dict]) -> Part: - if isinstance(part, str): - return Part.from_text(part) - - if not isinstance(part, Dict): - raise ValueError( - f"Message's content is expected to be a dict, got {type(part)}!" - ) - if part["type"] == "text": - return Part.from_text(part["text"]) - elif part["type"] == "image_url": - path = part["image_url"]["url"] - if path.startswith("gs://"): - image = load_image_from_gcs(path=path, project=project) - elif path.startswith("data:image/"): - # extract base64 component from image uri - encoded: Any = re.search(r"data:image/\w{2,4};base64,(.*)", path) - if encoded: - encoded = encoded.group(1) - else: - raise ValueError( - "Invalid image uri. It should be in the format " - "data:image/;base64,." - ) - image = Image.from_bytes(base64.b64decode(encoded)) - elif _is_url(path): - response = requests.get(path) - response.raise_for_status() - image = Image.from_bytes(response.content) - else: - image = Image.load_from_file(path) - else: - raise ValueError("Only text and image_url types are supported!") - return Part.from_image(image) - - vertex_messages = [] - for i, message in enumerate(history): - if i == 0 and isinstance(message, SystemMessage): - raise ValueError("SystemMessages are not yet supported!") - elif isinstance(message, AIMessage): - role = "model" - elif isinstance(message, HumanMessage): - role = "user" - else: - raise ValueError( - f"Unexpected message with type {type(message)} at the position {i}." - ) - - raw_content = message.content - if isinstance(raw_content, str): - raw_content = [raw_content] - parts = [_convert_to_prompt(part) for part in raw_content] - vertex_message = Content(role=role, parts=parts) - vertex_messages.append(vertex_message) - return vertex_messages - - -def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]: - from vertexai.language_models import InputOutputTextPair - - if len(examples) % 2 != 0: - raise ValueError( - f"Expect examples to have an even amount of messages, got {len(examples)}." - ) - example_pairs = [] - input_text = None - for i, example in enumerate(examples): - if i % 2 == 0: - if not isinstance(example, HumanMessage): - raise ValueError( - f"Expected the first message in a part to be from human, got " - f"{type(example)} for the {i}th message." - ) - input_text = example.content - if i % 2 == 1: - if not isinstance(example, AIMessage): - raise ValueError( - f"Expected the second message in a part to be from AI, got " - f"{type(example)} for the {i}th message." 
- ) - pair = InputOutputTextPair( - input_text=input_text, output_text=example.content - ) - example_pairs.append(pair) - return example_pairs - - -def _get_question(messages: List[BaseMessage]) -> HumanMessage: - """Get the human message at the end of a list of input messages to a chat model.""" - if not messages: - raise ValueError("You should provide at least one message to start the chat!") - question = messages[-1] - if not isinstance(question, HumanMessage): - raise ValueError( - f"Last message in the list should be from human, got {question.type}." - ) - return question - - -@deprecated( - since="0.0.12", - removal="1.0", - alternative_import="langchain_google_vertexai.ChatVertexAI", -) -class ChatVertexAI(_VertexAICommon, BaseChatModel): - """`Vertex AI` Chat large language models API.""" - - model_name: str = "chat-bison" - "Underlying model name." - examples: Optional[List[BaseMessage]] = None - - @classmethod - def is_lc_serializable(self) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" - return ["langchain", "chat_models", "vertexai"] - - @pre_init - def validate_environment(cls, values: Dict) -> Dict: - """Validate that the python package exists in environment.""" - is_gemini = is_gemini_model(values["model_name"]) - cls._try_init_vertexai(values) - try: - from vertexai.language_models import ChatModel, CodeChatModel - - if is_gemini: - from vertexai.preview.generative_models import ( - GenerativeModel, - ) - except ImportError: - raise_vertex_import_error() - if is_gemini: - values["client"] = GenerativeModel(model_name=values["model_name"]) - else: - if is_codey_model(values["model_name"]): - model_cls = CodeChatModel - else: - model_cls = ChatModel - values["client"] = model_cls.from_pretrained(values["model_name"]) - return values - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - """Generate next turn in the conversation. - - Args: - messages: The history of the conversation as a list of messages. Code chat - does not support context. - stop: The list of stop words (optional). - run_manager: The CallbackManager for LLM run, it's not used at the moment. - stream: Whether to use the streaming endpoint. - - Returns: - The ChatResult that contains outputs generated by the model. - - Raises: - ValueError: if the last message in the list is not from human. 
- """ - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - - question = _get_question(messages) - params = self._prepare_params(stop=stop, stream=False, **kwargs) - msg_params = {} - if "candidate_count" in params: - msg_params["candidate_count"] = params.pop("candidate_count") - - if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini(messages, project=self.project) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) - response = chat.send_message(message, generation_config=params) - else: - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples") or self.examples - if examples: - params["examples"] = _parse_examples(examples) - chat = self._start_chat(history, **params) - response = chat.send_message(question.content, **msg_params) - generations = [ - ChatGeneration(message=AIMessage(content=r.text)) - for r in response.candidates - ] - return ChatResult(generations=generations) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """Asynchronously generate next turn in the conversation. - - Args: - messages: The history of the conversation as a list of messages. Code chat - does not support context. - stop: The list of stop words (optional). - run_manager: The CallbackManager for LLM run, it's not used at the moment. - - Returns: - The ChatResult that contains outputs generated by the model. - - Raises: - ValueError: if the last message in the list is not from human. 
- """ - if "stream" in kwargs: - kwargs.pop("stream") - logger.warning("ChatVertexAI does not currently support async streaming.") - - params = self._prepare_params(stop=stop, **kwargs) - msg_params = {} - if "candidate_count" in params: - msg_params["candidate_count"] = params.pop("candidate_count") - - if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini(messages, project=self.project) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) - response = await chat.send_message_async(message, generation_config=params) - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples", None) - if examples: - params["examples"] = _parse_examples(examples) - chat = self._start_chat(history, **params) - response = await chat.send_message_async(question.content, **msg_params) - - generations = [ - ChatGeneration(message=AIMessage(content=r.text)) - for r in response.candidates - ] - return ChatResult(generations=generations) - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - params = self._prepare_params(stop=stop, stream=True, **kwargs) - if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini(messages, project=self.project) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) - responses = chat.send_message( - message, stream=True, generation_config=params - ) - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples", None) - if examples: - params["examples"] = _parse_examples(examples) - chat = self._start_chat(history, **params) - responses = chat.send_message_streaming(question.content, **params) - for response in responses: - chunk = ChatGenerationChunk(message=AIMessageChunk(content=response.text)) - if run_manager: - run_manager.on_llm_new_token(response.text, chunk=chunk) - yield chunk - - def _start_chat( - self, history: _ChatHistory, **kwargs: Any - ) -> Union[ChatSession, CodeChatSession]: - if not self.is_codey_model: - return self.client.start_chat( - context=history.context, message_history=history.history, **kwargs - ) - else: - return self.client.start_chat(message_history=history.history, **kwargs) diff --git a/libs/community/langchain_community/document_transformers/openai_functions.py b/libs/community/langchain_community/document_transformers/openai_functions.py index 88b57f20e..daa2d1ce5 100644 --- a/libs/community/langchain_community/document_transformers/openai_functions.py +++ b/libs/community/langchain_community/document_transformers/openai_functions.py @@ -14,7 +14,7 @@ class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel): Example: .. code-block:: python - from langchain_community.chat_models import ChatOpenAI + from langchain_openai import ChatOpenAI from langchain_community.document_transformers import OpenAIMetadataTagger from langchain_core.documents import Document @@ -97,7 +97,7 @@ def create_metadata_tagger( Example: .. 
code-block:: python - from langchain_community.chat_models import ChatOpenAI + from langchain_openai import ChatOpenAI from langchain_community.document_transformers import create_metadata_tagger from langchain_core.documents import Document diff --git a/libs/community/langchain_community/example_selectors/__init__.py b/libs/community/langchain_community/example_selectors/__init__.py index d29bf7372..edadfd195 100644 --- a/libs/community/langchain_community/example_selectors/__init__.py +++ b/libs/community/langchain_community/example_selectors/__init__.py @@ -1,10 +1,13 @@ -"""**Example selector** implements logic for selecting examples to include them -in prompts. +"""**Example selector** implements logic for selecting examples to include them in +prompts. + This allows us to select examples that are most relevant to the input. -There could be multiple strategies for selecting examples. For example, one could -select examples based on the similarity of the input to the examples. Another -strategy could be to select examples based on the diversity of the examples. +There could be multiple strategies for selecting examples. + +For example, one could select examples based on the similarity of the input to the +examples. Another strategy could be to select examples based on the diversity of the +examples. """ from langchain_community.example_selectors.ngram_overlap import ( diff --git a/libs/community/langchain_community/example_selectors/ngram_overlap.py b/libs/community/langchain_community/example_selectors/ngram_overlap.py index 92577acd5..0cbd6e71e 100644 --- a/libs/community/langchain_community/example_selectors/ngram_overlap.py +++ b/libs/community/langchain_community/example_selectors/ngram_overlap.py @@ -1,4 +1,4 @@ -"""Select and order examples based on ngram overlap score (sentence_bleu score). +"""Select and order examples based on ngram overlap score (`sentence_bleu` score). https://www.nltk.org/_modules/nltk/translate/bleu_score.html https://aclanthology.org/P02-1040.pdf @@ -13,11 +13,13 @@ def ngram_overlap_score(source: List[str], example: List[str]) -> float: - """Compute ngram overlap score of source and example as sentence_bleu score + """Compute ngram overlap score of source and example as `sentence_bleu` score from NLTK package. - Use sentence_bleu with method1 smoothing function and auto reweighting. + Use `sentence_bleu` with `method1` smoothing function and auto reweighting. + Return float value between 0.0 and 1.0 inclusive. + https://www.nltk.org/_modules/nltk/translate/bleu_score.html https://aclanthology.org/P02-1040.pdf """ @@ -40,7 +42,7 @@ def ngram_overlap_score(source: List[str], example: List[str]) -> float: class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel): - """Select and order examples based on ngram overlap score (sentence_bleu score + """Select and order examples based on ngram overlap score (`sentence_bleu` score from NLTK package). 
https://www.nltk.org/_modules/nltk/translate/bleu_score.html diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py index 45e005242..bcf77c13f 100644 --- a/libs/community/langchain_community/llms/__init__.py +++ b/libs/community/langchain_community/llms/__init__.py @@ -168,18 +168,6 @@ def _import_databricks() -> Type[BaseLLM]: return Databricks -# deprecated / only for back compat - do not add to __all__ -def _import_databricks_chat() -> Any: - warn_deprecated( - since="0.0.22", - removal="1.0", - alternative_import="langchain_community.chat_models.ChatDatabricks", - ) - from langchain_community.chat_models.databricks import ChatDatabricks - - return ChatDatabricks - - def _import_deepinfra() -> Type[BaseLLM]: from langchain_community.llms.deepinfra import DeepInfra @@ -1024,7 +1012,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "ctransformers": _import_ctransformers, "ctranslate2": _import_ctranslate2, "databricks": _import_databricks, - "databricks-chat": _import_databricks_chat, # deprecated / only for back compat "deepinfra": _import_deepinfra, "deepsparse": _import_deepsparse, "edenai": _import_edenai, diff --git a/libs/community/langchain_community/llms/opaqueprompts.py b/libs/community/langchain_community/llms/opaqueprompts.py index 46a2a2b36..b304f917d 100644 --- a/libs/community/langchain_community/llms/opaqueprompts.py +++ b/libs/community/langchain_community/llms/opaqueprompts.py @@ -25,7 +25,7 @@ class OpaquePrompts(LLM): .. code-block:: python from langchain_community.llms import OpaquePrompts - from langchain_community.chat_models import ChatOpenAI + from langchain_openai import ChatOpenAI op_llm = OpaquePrompts(base_llm=ChatOpenAI()) """ diff --git a/libs/community/langchain_community/retrievers/__init__.py b/libs/community/langchain_community/retrievers/__init__.py index ce4ac731b..277af1249 100644 --- a/libs/community/langchain_community/retrievers/__init__.py +++ b/libs/community/langchain_community/retrievers/__init__.py @@ -103,9 +103,6 @@ from langchain_community.retrievers.pubmed import ( PubMedRetriever, ) - from langchain_community.retrievers.qdrant_sparse_vector_retriever import ( - QdrantSparseVectorRetriever, - ) from langchain_community.retrievers.rememberizer import ( RememberizerRetriever, ) @@ -178,7 +175,6 @@ "OutlineRetriever": "langchain_community.retrievers.outline", "PineconeHybridSearchRetriever": "langchain_community.retrievers.pinecone_hybrid_search", # noqa: E501 "PubMedRetriever": "langchain_community.retrievers.pubmed", - "QdrantSparseVectorRetriever": "langchain_community.retrievers.qdrant_sparse_vector_retriever", # noqa: E501 "RememberizerRetriever": "langchain_community.retrievers.rememberizer", "RemoteLangChainRetriever": "langchain_community.retrievers.remote_retriever", "SVMRetriever": "langchain_community.retrievers.svm", @@ -236,7 +232,6 @@ def __getattr__(name: str) -> Any: "OutlineRetriever", "PineconeHybridSearchRetriever", "PubMedRetriever", - "QdrantSparseVectorRetriever", "RememberizerRetriever", "RemoteLangChainRetriever", "SVMRetriever", diff --git a/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retriever.py b/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retriever.py deleted file mode 100644 index 1b64c3467..000000000 --- a/libs/community/langchain_community/retrievers/qdrant_sparse_vector_retriever.py +++ /dev/null @@ -1,220 +0,0 @@ -import uuid -from itertools import islice -from typing import ( 
- Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - Tuple, - cast, -) - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import CallbackManagerForRetrieverRun -from langchain_core.documents import Document -from langchain_core.retrievers import BaseRetriever -from langchain_core.utils import pre_init -from pydantic import ConfigDict - -from langchain_community.vectorstores.qdrant import Qdrant, QdrantException - - -@deprecated( - since="0.2.16", - alternative=( - "Qdrant vector store now supports sparse retrievals natively. " - "Use langchain_qdrant.QdrantVectorStore#as_retriever() instead. " - "Reference: " - "https://python.langchain.com/docs/integrations/vectorstores/qdrant/#sparse-vector-search" - ), - removal="0.5.0", -) -class QdrantSparseVectorRetriever(BaseRetriever): - """Qdrant sparse vector retriever.""" - - client: Any = None - """'qdrant_client' instance to use.""" - collection_name: str - """Qdrant collection name.""" - sparse_vector_name: str - """Name of the sparse vector to use.""" - sparse_encoder: Callable[[str], Tuple[List[int], List[float]]] - """Sparse encoder function to use.""" - k: int = 4 - """Number of documents to return per query. Defaults to 4.""" - filter: Optional[Any] = None - """Qdrant qdrant_client.models.Filter to use for queries. Defaults to None.""" - content_payload_key: str = "content" - """Payload field containing the document content. Defaults to 'content'""" - metadata_payload_key: str = "metadata" - """Payload field containing the document metadata. Defaults to 'metadata'.""" - search_options: Dict[str, Any] = {} - """Additional search options to pass to qdrant_client.QdrantClient.search().""" - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - @pre_init - def validate_environment(cls, values: Dict) -> Dict: - """Validate that 'qdrant_client' python package exists in environment.""" - try: - from grpc import RpcError - from qdrant_client import QdrantClient, models - from qdrant_client.http.exceptions import UnexpectedResponse - except ImportError: - raise ImportError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - - client = values["client"] - if not isinstance(client, QdrantClient): - raise ValueError( - f"client should be an instance of qdrant_client.QdrantClient, " - f"got {type(client)}" - ) - - filter = values["filter"] - if filter is not None and not isinstance(filter, models.Filter): - raise ValueError( - f"filter should be an instance of qdrant_client.models.Filter, " - f"got {type(filter)}" - ) - - client = cast(QdrantClient, client) - - collection_name = values["collection_name"] - sparse_vector_name = values["sparse_vector_name"] - - try: - collection_info = client.get_collection(collection_name) - sparse_vectors_config = collection_info.config.params.sparse_vectors - - if sparse_vector_name not in sparse_vectors_config: - raise QdrantException( - f"Existing Qdrant collection {collection_name} does not " - f"contain sparse vector named {sparse_vector_name}." - f"Did you mean one of {', '.join(sparse_vectors_config.keys())}?" - ) - except (UnexpectedResponse, RpcError, ValueError): - raise QdrantException( - f"Qdrant collection {collection_name} does not exist." 
- ) - return values - - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - from qdrant_client import QdrantClient, models - - client = cast(QdrantClient, self.client) - query_indices, query_values = self.sparse_encoder(query) - results = client.search( - self.collection_name, - query_filter=self.filter, - query_vector=models.NamedSparseVector( - name=self.sparse_vector_name, - vector=models.SparseVector( - indices=query_indices, - values=query_values, - ), - ), - limit=self.k, - with_vectors=False, - **self.search_options, - ) - return [ - Qdrant._document_from_scored_point( - point, - self.collection_name, - self.content_payload_key, - self.metadata_payload_key, - ) - for point in results - ] - - def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: - """Run more documents through the embeddings and add to the vectorstore. - - Args: - documents (List[Document]: Documents to add to the vectorstore. - - Returns: - List[str]: List of IDs of the added texts. - """ - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - return self.add_texts(texts, metadatas, **kwargs) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - **kwargs: Any, - ) -> List[str]: - from qdrant_client import QdrantClient - - added_ids = [] - client = cast(QdrantClient, self.client) - for batch_ids, points in self._generate_rest_batches( - texts, metadatas, ids, batch_size - ): - client.upsert(self.collection_name, points=points, **kwargs) - added_ids.extend(batch_ids) - - return added_ids - - def _generate_rest_batches( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - ) -> Generator[Tuple[List[str], List[Any]], None, None]: - from qdrant_client import models as rest - - texts_iterator = iter(texts) - metadatas_iterator = iter(metadatas or []) - ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) - while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata and id for each text in a batch - batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None - batch_ids = list(islice(ids_iterator, batch_size)) - - # Generate the sparse embeddings for all the texts in a batch - batch_embeddings: List[Tuple[List[int], List[float]]] = [ - self.sparse_encoder(text) for text in batch_texts - ] - - points = [ - rest.PointStruct( - id=point_id, - vector={ - self.sparse_vector_name: rest.SparseVector( - indices=sparse_vector[0], - values=sparse_vector[1], - ) - }, - payload=payload, - ) - for point_id, sparse_vector, payload in zip( - batch_ids, - batch_embeddings, - Qdrant._build_payloads( - batch_texts, - batch_metadatas, - self.content_payload_key, - self.metadata_payload_key, - ), - ) - ] - - yield batch_ids, points diff --git a/libs/community/langchain_community/retrievers/web_research.py b/libs/community/langchain_community/retrievers/web_research.py index be684eaed..2eb422cfc 100644 --- a/libs/community/langchain_community/retrievers/web_research.py +++ b/libs/community/langchain_community/retrievers/web_research.py @@ -20,7 +20,7 @@ from langchain_community.document_loaders import AsyncHtmlLoader from langchain_community.document_transformers import Html2TextTransformer from langchain_community.llms import LlamaCpp -from 
langchain_community.utilities import GoogleSearchAPIWrapper +from langchain_community.utilities.google_search import GoogleSearchAPIWrapper logger = logging.getLogger(__name__) @@ -82,12 +82,12 @@ class WebResearchRetriever(BaseRetriever): ) allow_dangerous_requests: bool = False - """A flag to force users to acknowledge the risks of SSRF attacks when using + """A flag to force users to acknowledge the risks of SSRF attacks when using this retriever. - + Users should set this flag to `True` if they have taken the necessary precautions to prevent SSRF attacks when using this retriever. - + For example, users can run the requests through a properly configured proxy and prevent the crawler from accidentally crawling internal resources. """ diff --git a/libs/community/langchain_community/tools/__init__.py b/libs/community/langchain_community/tools/__init__.py index de486cfbb..b78644ee2 100644 --- a/libs/community/langchain_community/tools/__init__.py +++ b/libs/community/langchain_community/tools/__init__.py @@ -146,13 +146,6 @@ from langchain_community.tools.google_cloud.texttospeech import ( GoogleCloudTextToSpeechTool, ) - from langchain_community.tools.google_places.tool import ( - GooglePlacesTool, - ) - from langchain_community.tools.google_search.tool import ( - GoogleSearchResults, - GoogleSearchRun, - ) from langchain_community.tools.google_serper.tool import ( GoogleSerperResults, GoogleSerperRun, @@ -413,9 +406,6 @@ "GmailSendMessage", "GoogleBooksQueryRun", "GoogleCloudTextToSpeechTool", - "GooglePlacesTool", - "GoogleSearchResults", - "GoogleSearchRun", "GoogleSerperResults", "GoogleSerperRun", "HumanInputRun", @@ -567,9 +557,6 @@ "GmailSendMessage": "langchain_community.tools.gmail", "GoogleBooksQueryRun": "langchain_community.tools.google_books", "GoogleCloudTextToSpeechTool": "langchain_community.tools.google_cloud.texttospeech", # noqa: E501 - "GooglePlacesTool": "langchain_community.tools.google_places.tool", - "GoogleSearchResults": "langchain_community.tools.google_search.tool", - "GoogleSearchRun": "langchain_community.tools.google_search.tool", "GoogleSerperResults": "langchain_community.tools.google_serper.tool", "GoogleSerperRun": "langchain_community.tools.google_serper.tool", "HumanInputRun": "langchain_community.tools.human.tool", diff --git a/libs/community/langchain_community/tools/amadeus/closest_airport.py b/libs/community/langchain_community/tools/amadeus/closest_airport.py index 9523f73af..0aa5b304c 100644 --- a/libs/community/langchain_community/tools/amadeus/closest_airport.py +++ b/libs/community/langchain_community/tools/amadeus/closest_airport.py @@ -4,7 +4,7 @@ from langchain_core.language_models import BaseLanguageModel from pydantic import BaseModel, Field, model_validator -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.tools.amadeus.base import AmadeusBaseTool diff --git a/libs/community/langchain_community/tools/google_places/__init__.py b/libs/community/langchain_community/tools/google_places/__init__.py deleted file mode 100644 index 6d3b948ea..000000000 --- a/libs/community/langchain_community/tools/google_places/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Google Places API Toolkit.""" - -from langchain_community.tools.google_places.tool import GooglePlacesTool - -__all__ = ["GooglePlacesTool"] diff --git a/libs/community/langchain_community/tools/google_places/tool.py b/libs/community/langchain_community/tools/google_places/tool.py deleted file mode 
100644 index 77a146907..000000000 --- a/libs/community/langchain_community/tools/google_places/tool.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Tool for the Google search API.""" - -from typing import Optional, Type - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import CallbackManagerForToolRun -from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field - -from langchain_community.utilities.google_places_api import GooglePlacesAPIWrapper - - -class GooglePlacesSchema(BaseModel): - """Input for GooglePlacesTool.""" - - query: str = Field(..., description="Query for google maps") - - -@deprecated( - since="0.0.33", - removal="1.0", - alternative_import="langchain_google_community.GooglePlacesTool", -) -class GooglePlacesTool(BaseTool): - """Tool that queries the Google places API.""" - - name: str = "google_places" - description: str = ( - "A wrapper around Google Places. " - "Useful for when you need to validate or " - "discover addressed from ambiguous text. " - "Input should be a search query." - ) - api_wrapper: GooglePlacesAPIWrapper = Field(default_factory=GooglePlacesAPIWrapper) - args_schema: Type[BaseModel] = GooglePlacesSchema - - def _run( - self, - query: str, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - """Use the tool.""" - return self.api_wrapper.run(query) diff --git a/libs/community/langchain_community/tools/google_search/__init__.py b/libs/community/langchain_community/tools/google_search/__init__.py deleted file mode 100644 index 08eccf0a3..000000000 --- a/libs/community/langchain_community/tools/google_search/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Google Search API Toolkit.""" - -from langchain_community.tools.google_search.tool import ( - GoogleSearchResults, - GoogleSearchRun, -) - -__all__ = ["GoogleSearchRun", "GoogleSearchResults"] diff --git a/libs/community/langchain_community/tools/google_search/tool.py b/libs/community/langchain_community/tools/google_search/tool.py deleted file mode 100644 index 3ba05079d..000000000 --- a/libs/community/langchain_community/tools/google_search/tool.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Tool for the Google search API.""" - -from typing import Optional - -from langchain_core._api.deprecation import deprecated -from langchain_core.callbacks import CallbackManagerForToolRun -from langchain_core.tools import BaseTool - -from langchain_community.utilities.google_search import GoogleSearchAPIWrapper - - -@deprecated( - since="0.0.33", - removal="1.0", - alternative_import="langchain_google_community.GoogleSearchRun", -) -class GoogleSearchRun(BaseTool): - """Tool that queries the Google search API.""" - - name: str = "google_search" - description: str = ( - "A wrapper around Google Search. " - "Useful for when you need to answer questions about current events. " - "Input should be a search query." - ) - api_wrapper: GoogleSearchAPIWrapper - - def _run( - self, - query: str, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - """Use the tool.""" - return self.api_wrapper.run(query) - - -@deprecated( - since="0.0.33", - removal="1.0", - alternative_import="langchain_google_community.GoogleSearchResults", -) -class GoogleSearchResults(BaseTool): - """Tool that queries the Google Search API and gets back json.""" - - name: str = "google_search_results_json" - description: str = ( - "A wrapper around Google Search. " - "Useful for when you need to answer questions about current events. " - "Input should be a search query. 
Output is a JSON array of the query results" - ) - num_results: int = 4 - api_wrapper: GoogleSearchAPIWrapper - - def _run( - self, - query: str, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - """Use the tool.""" - return str(self.api_wrapper.results(query, self.num_results)) diff --git a/libs/community/langchain_community/tools/powerbi/tool.py b/libs/community/langchain_community/tools/powerbi/tool.py index c5ec51e5b..5806541c9 100644 --- a/libs/community/langchain_community/tools/powerbi/tool.py +++ b/libs/community/langchain_community/tools/powerbi/tool.py @@ -11,7 +11,6 @@ from langchain_core.tools import BaseTool from pydantic import ConfigDict, Field, model_validator -from langchain_community.chat_models.openai import _import_tiktoken from langchain_community.tools.powerbi.prompt import ( BAD_REQUEST_RESPONSE, DEFAULT_FEWSHOT_EXAMPLES, @@ -22,6 +21,18 @@ logger = logging.getLogger(__name__) +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + class QueryPowerBITool(BaseTool): """Tool for querying a Power BI Dataset.""" diff --git a/libs/community/langchain_community/utilities/__init__.py b/libs/community/langchain_community/utilities/__init__.py index 0174d37c0..461bce79f 100644 --- a/libs/community/langchain_community/utilities/__init__.py +++ b/libs/community/langchain_community/utilities/__init__.py @@ -1,7 +1,7 @@ -"""**Utilities** are the integrations with third-part systems and packages. +"""**Utilities** are the integrations with third-party systems and packages. -Other LangChain classes use **Utilities** to interact with third-part systems -and packages. +Other LangChain classes use **Utilities** to interact with third-party systems and +packages. 
""" import importlib @@ -57,15 +57,9 @@ from langchain_community.utilities.google_lens import ( GoogleLensAPIWrapper, ) - from langchain_community.utilities.google_places_api import ( - GooglePlacesAPIWrapper, - ) from langchain_community.utilities.google_scholar import ( GoogleScholarAPIWrapper, ) - from langchain_community.utilities.google_search import ( - GoogleSearchAPIWrapper, - ) from langchain_community.utilities.google_serper import ( GoogleSerperAPIWrapper, ) @@ -107,9 +101,6 @@ from langchain_community.utilities.openweathermap import ( OpenWeatherMapAPIWrapper, ) - from langchain_community.utilities.oracleai import ( - OracleSummary, - ) from langchain_community.utilities.outline import ( OutlineAPIWrapper, ) @@ -192,9 +183,7 @@ "GoogleFinanceAPIWrapper", "GoogleJobsAPIWrapper", "GoogleLensAPIWrapper", - "GooglePlacesAPIWrapper", "GoogleScholarAPIWrapper", - "GoogleSearchAPIWrapper", "GoogleSerperAPIWrapper", "GoogleTrendsAPIWrapper", "GraphQLAPIWrapper", @@ -211,7 +200,6 @@ "NasaAPIWrapper", "NutritionAIAPI", "OpenWeatherMapAPIWrapper", - "OracleSummary", "OutlineAPIWrapper", "Portkey", "PowerBIDataset", @@ -256,9 +244,7 @@ "GoogleFinanceAPIWrapper": "langchain_community.utilities.google_finance", "GoogleJobsAPIWrapper": "langchain_community.utilities.google_jobs", "GoogleLensAPIWrapper": "langchain_community.utilities.google_lens", - "GooglePlacesAPIWrapper": "langchain_community.utilities.google_places_api", "GoogleScholarAPIWrapper": "langchain_community.utilities.google_scholar", - "GoogleSearchAPIWrapper": "langchain_community.utilities.google_search", "GoogleSerperAPIWrapper": "langchain_community.utilities.google_serper", "GoogleTrendsAPIWrapper": "langchain_community.utilities.google_trends", "GraphQLAPIWrapper": "langchain_community.utilities.graphql", @@ -275,7 +261,6 @@ "NasaAPIWrapper": "langchain_community.utilities.nasa", "NutritionAIAPI": "langchain_community.utilities.passio_nutrition_ai", "OpenWeatherMapAPIWrapper": "langchain_community.utilities.openweathermap", - "OracleSummary": "langchain_community.utilities.oracleai", "OutlineAPIWrapper": "langchain_community.utilities.outline", "Portkey": "langchain_community.utilities.portkey", "PowerBIDataset": "langchain_community.utilities.powerbi", diff --git a/libs/community/langchain_community/utilities/google_places_api.py b/libs/community/langchain_community/utilities/google_places_api.py deleted file mode 100644 index 423aeee6e..000000000 --- a/libs/community/langchain_community/utilities/google_places_api.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Chain that calls Google Places API.""" - -import logging -from typing import Any, Dict, Optional - -from langchain_core._api.deprecation import deprecated -from langchain_core.utils import get_from_dict_or_env -from pydantic import BaseModel, ConfigDict, model_validator - - -@deprecated( - since="0.0.33", - removal="1.0", - alternative_import="langchain_google_community.GooglePlacesAPIWrapper", -) -class GooglePlacesAPIWrapper(BaseModel): - """Wrapper around Google Places API. - - To use, you should have the ``googlemaps`` python package installed, - **an API key for the google maps platform**, - and the environment variable ''GPLACES_API_KEY'' - set with your API key , or pass 'gplaces_api_key' - as a named parameter to the constructor. - - By default, this will return the all the results on the input query. - You can use the top_k_results argument to limit the number of results. - - Example: - .. 
code-block:: python - - - from langchain_community.utilities import GooglePlacesAPIWrapper - gplaceapi = GooglePlacesAPIWrapper() - """ - - gplaces_api_key: Optional[str] = None - google_map_client: Any = None #: :meta private: - top_k_results: Optional[int] = None - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - @model_validator(mode="before") - @classmethod - def validate_environment(cls, values: Dict) -> Any: - """Validate that api key is in your environment variable.""" - gplaces_api_key = get_from_dict_or_env( - values, "gplaces_api_key", "GPLACES_API_KEY" - ) - values["gplaces_api_key"] = gplaces_api_key - try: - import googlemaps - - values["google_map_client"] = googlemaps.Client(gplaces_api_key) - except ImportError: - raise ImportError( - "Could not import googlemaps python package. " - "Please install it with `pip install googlemaps`." - ) - return values - - def run(self, query: str) -> str: - """Run Places search and get k number of places that exists that match.""" - search_results = self.google_map_client.places(query)["results"] - num_to_return = len(search_results) - - places = [] - - if num_to_return == 0: - return "Google Places did not find any places that match the description" - - num_to_return = ( - num_to_return - if self.top_k_results is None - else min(num_to_return, self.top_k_results) - ) - - for i in range(num_to_return): - result = search_results[i] - details = self.fetch_place_details(result["place_id"]) - - if details is not None: - places.append(details) - - return "\n".join([f"{i + 1}. {item}" for i, item in enumerate(places)]) - - def fetch_place_details(self, place_id: str) -> Optional[str]: - try: - place_details = self.google_map_client.place(place_id) - place_details["place_id"] = place_id - formatted_details = self.format_place_details(place_details) - return formatted_details - except Exception as e: - logging.error(f"An Error occurred while fetching place details: {e}") - return None - - def format_place_details(self, place_details: Dict[str, Any]) -> Optional[str]: - try: - name = place_details.get("result", {}).get("name", "Unknown") - address = place_details.get("result", {}).get( - "formatted_address", "Unknown" - ) - phone_number = place_details.get("result", {}).get( - "formatted_phone_number", "Unknown" - ) - website = place_details.get("result", {}).get("website", "Unknown") - place_id = place_details.get("result", {}).get("place_id", "Unknown") - - formatted_details = ( - f"{name}\nAddress: {address}\n" - f"Google place ID: {place_id}\n" - f"Phone: {phone_number}\nWebsite: {website}\n\n" - ) - return formatted_details - except Exception as e: - logging.error(f"An error occurred while formatting place details: {e}") - return None diff --git a/libs/community/langchain_community/utilities/google_search.py b/libs/community/langchain_community/utilities/google_search.py index 2037b34c8..65acf48c8 100644 --- a/libs/community/langchain_community/utilities/google_search.py +++ b/libs/community/langchain_community/utilities/google_search.py @@ -1,4 +1,4 @@ -"""Util that calls Google Search.""" +"""DO NOT USE. 
Kept for backward compatibility for web search retriever.""" from typing import Any, Dict, List, Optional diff --git a/libs/community/langchain_community/utilities/oracleai.py b/libs/community/langchain_community/utilities/oracleai.py deleted file mode 100644 index e8587cac8..000000000 --- a/libs/community/langchain_community/utilities/oracleai.py +++ /dev/null @@ -1,214 +0,0 @@ -# Authors: -# Harichandan Roy (hroy) -# David Jiang (ddjiang) -# -# ----------------------------------------------------------------------------- -# oracleai.py -# ----------------------------------------------------------------------------- - -from __future__ import annotations - -import json -import logging -import traceback -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from langchain_core._api import deprecated -from langchain_core.documents import Document - -if TYPE_CHECKING: - from oracledb import Connection - -logger = logging.getLogger(__name__) - -"""OracleSummary class""" - - -@deprecated( - since="0.3.30", - removal="1.0", - message=( - "This class is deprecated and will be removed in a future release. " - "Instead, please use `OracleSummary` from the " - "`langchain-oracledb` package. " - "For more information, refer to ." - ), - alternative="from langchain_oracledb.utilities import OracleSummary;", - pending=False, -) -class OracleSummary: - """Get Summary - Args: - conn: Oracle Connection, - params: Summary parameters, - proxy: Proxy - """ - - def __init__( - self, conn: Connection, params: Dict[str, Any], proxy: Optional[str] = None - ): - self.conn = conn - self.proxy = proxy - self.summary_params = params - - def get_summary(self, docs: Any) -> List[str]: - """Get the summary of the input docs. - Args: - docs: The documents to generate summary for. - Allowed input types: str, Document, List[str], List[Document] - Returns: - List of summary text, one for each input doc. - """ - - try: - import oracledb - except ImportError as e: - raise ImportError( - "Unable to import oracledb, please install with " - "`pip install -U oracledb`." 
- ) from e - - if docs is None: - return [] - - results: List[str] = [] - try: - oracledb.defaults.fetch_lobs = False - cursor = self.conn.cursor() - - if self.proxy: - cursor.execute( - "begin utl_http.set_proxy(:proxy); end;", proxy=self.proxy - ) - - if isinstance(docs, str): - results = [] - - summary = cursor.var(oracledb.DB_TYPE_CLOB) - cursor.execute( - """ - declare - input clob; - begin - input := :data; - :summ := dbms_vector_chain.utl_to_summary(input, json(:params)); - end;""", - data=docs, - params=json.dumps(self.summary_params), - summ=summary, - ) - - if summary is None: - results.append("") - else: - results.append(str(summary.getvalue())) - - elif isinstance(docs, Document): - results = [] - - summary = cursor.var(oracledb.DB_TYPE_CLOB) - cursor.execute( - """ - declare - input clob; - begin - input := :data; - :summ := dbms_vector_chain.utl_to_summary(input, json(:params)); - end;""", - data=docs.page_content, - params=json.dumps(self.summary_params), - summ=summary, - ) - - if summary is None: - results.append("") - else: - results.append(str(summary.getvalue())) - - elif isinstance(docs, List): - results = [] - - for doc in docs: - summary = cursor.var(oracledb.DB_TYPE_CLOB) - if isinstance(doc, str): - cursor.execute( - """ - declare - input clob; - begin - input := :data; - :summ := dbms_vector_chain.utl_to_summary(input, - json(:params)); - end;""", - data=doc, - params=json.dumps(self.summary_params), - summ=summary, - ) - - elif isinstance(doc, Document): - cursor.execute( - """ - declare - input clob; - begin - input := :data; - :summ := dbms_vector_chain.utl_to_summary(input, - json(:params)); - end;""", - data=doc.page_content, - params=json.dumps(self.summary_params), - summ=summary, - ) - - else: - raise Exception("Invalid input type") - - if summary is None: - results.append("") - else: - results.append(str(summary.getvalue())) - - else: - raise Exception("Invalid input type") - - cursor.close() - return results - - except Exception as ex: - logger.info(f"An exception occurred :: {ex}") - traceback.print_exc() - cursor.close() - raise - - -# uncomment the following code block to run the test - -""" -# A sample unit test. - -''' get the Oracle connection ''' -conn = oracledb.connect( - user="", - password="", - dsn="") -print("Oracle connection is established...") - -''' params ''' -summary_params = {"provider": "database","glevel": "S", - "numParagraphs": 1,"language": "english"} -proxy = "" - -''' instance ''' -summ = OracleSummary(conn=conn, params=summary_params, proxy=proxy) - -summary = summ.get_summary("In the heart of the forest, " + - "a lone fox ventured out at dusk, seeking a lost treasure. " + - "With each step, memories flooded back, guiding its path. " + - "As the moon rose high, illuminating the night, the fox unearthed " + - "not gold, but a forgotten friendship, worth more than any riches.") -print(f"Summary generated by OracleSummary: {summary}") - -conn.close() -print("Connection is closed.") - -""" diff --git a/libs/community/langchain_community/utils/__init__.py b/libs/community/langchain_community/utils/__init__.py index de1316ff4..657d047d3 100644 --- a/libs/community/langchain_community/utils/__init__.py +++ b/libs/community/langchain_community/utils/__init__.py @@ -1,3 +1 @@ -""" -**Utility functions** for LangChain. 
-""" +"""`langchain-community` utilities.""" diff --git a/libs/community/langchain_community/utils/openai_functions.py b/libs/community/langchain_community/utils/openai_functions.py index b020c8aae..067d76684 100644 --- a/libs/community/langchain_community/utils/openai_functions.py +++ b/libs/community/langchain_community/utils/openai_functions.py @@ -1,4 +1,4 @@ -# these stubs are just for backwards compatibility +# These are just for backwards compatibility from langchain_core.utils.function_calling import ( FunctionDescription, diff --git a/libs/community/langchain_community/vectorstores/__init__.py b/libs/community/langchain_community/vectorstores/__init__.py index c769728c0..b80a7abe2 100644 --- a/libs/community/langchain_community/vectorstores/__init__.py +++ b/libs/community/langchain_community/vectorstores/__init__.py @@ -3,20 +3,6 @@ One of the most common ways to store and search over unstructured data is to embed it and store the resulting embedding vectors, and then query the store and retrieve the data that are 'most similar' to the embedded query. - -**Class hierarchy:** - -.. code-block:: - - VectorStore --> # Examples: Annoy, FAISS, Milvus - - BaseRetriever --> VectorStoreRetriever --> Retriever # Example: VespaRetriever - -**Main helpers:** - -.. code-block:: - - Embeddings, Document """ # noqa: E501 import importlib @@ -95,12 +81,6 @@ from langchain_community.vectorstores.dashvector import ( DashVector, ) - from langchain_community.vectorstores.databricks_vector_search import ( - DatabricksVectorSearch, - ) - from langchain_community.vectorstores.deeplake import ( - DeepLake, - ) from langchain_community.vectorstores.dingo import ( Dingo, ) @@ -130,9 +110,6 @@ from langchain_community.vectorstores.faiss import ( FAISS, ) - from langchain_community.vectorstores.hanavector import ( - HanaDB, - ) from langchain_community.vectorstores.hologres import ( Hologres, ) @@ -166,34 +143,19 @@ from langchain_community.vectorstores.marqo import ( Marqo, ) - from langchain_community.vectorstores.matching_engine import ( - MatchingEngine, - ) from langchain_community.vectorstores.meilisearch import ( Meilisearch, ) - from langchain_community.vectorstores.milvus import ( - Milvus, - ) from langchain_community.vectorstores.momento_vector_index import ( MomentoVectorIndex, ) - from langchain_community.vectorstores.mongodb_atlas import ( - MongoDBAtlasVectorSearch, - ) from langchain_community.vectorstores.myscale import ( MyScale, MyScaleSettings, ) - from langchain_community.vectorstores.neo4j_vector import ( - Neo4jVector, - ) from langchain_community.vectorstores.opensearch_vector_search import ( OpenSearchVectorSearch, ) - from langchain_community.vectorstores.oraclevs import ( - OracleVS, - ) from langchain_community.vectorstores.pathway import ( PathwayVectorClient, ) @@ -203,12 +165,6 @@ from langchain_community.vectorstores.pgvector import ( PGVector, ) - from langchain_community.vectorstores.pinecone import ( - Pinecone, - ) - from langchain_community.vectorstores.qdrant import ( - Qdrant, - ) from langchain_community.vectorstores.redis import ( Redis, ) @@ -279,9 +235,6 @@ from langchain_community.vectorstores.vald import ( Vald, ) - from langchain_community.vectorstores.vdms import ( - VDMS, - ) from langchain_community.vectorstores.vearch import ( Vearch, ) @@ -294,9 +247,6 @@ from langchain_community.vectorstores.vlite import ( VLite, ) - from langchain_community.vectorstores.weaviate import ( - Weaviate, - ) from langchain_community.vectorstores.yellowbrick import ( 
Yellowbrick, ) @@ -335,8 +285,6 @@ "ClickhouseSettings", "CouchbaseVectorStore", "DashVector", - "DatabricksVectorSearch", - "DeepLake", "Dingo", "DistanceStrategy", "DocArrayHnswSearch", @@ -349,7 +297,6 @@ "ElasticsearchStore", "Epsilla", "FAISS", - "HanaDB", "Hologres", "InMemoryVectorStore", "InfinispanVS", @@ -362,23 +309,16 @@ "ManticoreSearch", "ManticoreSearchSettings", "Marqo", - "MatchingEngine", "Meilisearch", - "Milvus", "MomentoVectorIndex", - "MongoDBAtlasVectorSearch", "MyScale", "MyScaleSettings", - "Neo4jVector", "NeuralDBClientVectorStore", "NeuralDBVectorStore", - "OracleVS", "OpenSearchVectorSearch", "PGEmbedding", "PGVector", "PathwayVectorClient", - "Pinecone", - "Qdrant", "Redis", "Relyt", "Rockset", @@ -400,14 +340,12 @@ "Typesense", "UpstashVectorStore", "USearch", - "VDMS", "Vald", "Vearch", "Vectara", "VectorStore", "VespaStore", "VLite", - "Weaviate", "Yellowbrick", "ZepVectorStore", "ZepCloudVectorStore", @@ -439,8 +377,6 @@ "ClickhouseSettings": "langchain_community.vectorstores.clickhouse", "CouchbaseVectorStore": "langchain_community.vectorstores.couchbase", "DashVector": "langchain_community.vectorstores.dashvector", - "DatabricksVectorSearch": "langchain_community.vectorstores.databricks_vector_search", # noqa: E501 - "DeepLake": "langchain_community.vectorstores.deeplake", "Dingo": "langchain_community.vectorstores.dingo", "DistanceStrategy": "langchain_community.vectorstores.kinetica", "DocArrayHnswSearch": "langchain_community.vectorstores.docarray", @@ -453,7 +389,6 @@ "ElasticsearchStore": "langchain_community.vectorstores.elasticsearch", "Epsilla": "langchain_community.vectorstores.epsilla", "FAISS": "langchain_community.vectorstores.faiss", - "HanaDB": "langchain_community.vectorstores.hanavector", "Hologres": "langchain_community.vectorstores.hologres", "InfinispanVS": "langchain_community.vectorstores.infinispanvs", "InMemoryVectorStore": "langchain_community.vectorstores.inmemory", @@ -466,23 +401,16 @@ "ManticoreSearch": "langchain_community.vectorstores.manticore_search", "ManticoreSearchSettings": "langchain_community.vectorstores.manticore_search", "Marqo": "langchain_community.vectorstores.marqo", - "MatchingEngine": "langchain_community.vectorstores.matching_engine", "Meilisearch": "langchain_community.vectorstores.meilisearch", - "Milvus": "langchain_community.vectorstores.milvus", "MomentoVectorIndex": "langchain_community.vectorstores.momento_vector_index", - "MongoDBAtlasVectorSearch": "langchain_community.vectorstores.mongodb_atlas", "MyScale": "langchain_community.vectorstores.myscale", "MyScaleSettings": "langchain_community.vectorstores.myscale", - "Neo4jVector": "langchain_community.vectorstores.neo4j_vector", "NeuralDBClientVectorStore": "langchain_community.vectorstores.thirdai_neuraldb", "NeuralDBVectorStore": "langchain_community.vectorstores.thirdai_neuraldb", "OpenSearchVectorSearch": "langchain_community.vectorstores.opensearch_vector_search", # noqa: E501 - "OracleVS": "langchain_community.vectorstores.oraclevs", "PathwayVectorClient": "langchain_community.vectorstores.pathway", "PGEmbedding": "langchain_community.vectorstores.pgembedding", "PGVector": "langchain_community.vectorstores.pgvector", - "Pinecone": "langchain_community.vectorstores.pinecone", - "Qdrant": "langchain_community.vectorstores.qdrant", "Redis": "langchain_community.vectorstores.redis", "Relyt": "langchain_community.vectorstores.relyt", "Rockset": "langchain_community.vectorstores.rocksetdb", @@ -505,13 +433,11 @@ "UpstashVectorStore": 
"langchain_community.vectorstores.upstash", "USearch": "langchain_community.vectorstores.usearch", "Vald": "langchain_community.vectorstores.vald", - "VDMS": "langchain_community.vectorstores.vdms", "Vearch": "langchain_community.vectorstores.vearch", "Vectara": "langchain_community.vectorstores.vectara", "VectorStore": "langchain_core.vectorstores", "VespaStore": "langchain_community.vectorstores.vespa", "VLite": "langchain_community.vectorstores.vlite", - "Weaviate": "langchain_community.vectorstores.weaviate", "Yellowbrick": "langchain_community.vectorstores.yellowbrick", "ZepVectorStore": "langchain_community.vectorstores.zep", "ZepCloudVectorStore": "langchain_community.vectorstores.zep_cloud", diff --git a/libs/community/langchain_community/vectorstores/apache_doris.py b/libs/community/langchain_community/vectorstores/apache_doris.py index 6b5357260..4a86e7a4f 100644 --- a/libs/community/langchain_community/vectorstores/apache_doris.py +++ b/libs/community/langchain_community/vectorstores/apache_doris.py @@ -108,7 +108,7 @@ def __init__( config (ApacheDorisSettings): Apache Doris client configuration information. """ try: - import pymysql # type: ignore[import-untyped] + import pymysql # type: ignore[import-untyped, unused-ignore] except ImportError: raise ImportError( "Could not import pymysql python package. " @@ -136,7 +136,7 @@ def __init__( dim = len(embedding.embed_query("test")) self.schema = f"""\ -CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( +CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( {self.config.column_map["id"]} varchar(50), {self.config.column_map["document"]} string, {self.config.column_map["embedding"]} array, @@ -324,10 +324,10 @@ def _build_query_sql( where_str = "" q_str = f""" - SELECT + SELECT id as id, - {self.config.column_map["document"]} as document, - {self.config.column_map["metadata"]} as metadata, + {self.config.column_map["document"]} as document, + {self.config.column_map["metadata"]} as metadata, cosine_distance(array[{q_emb_str}], {self.config.column_map["embedding"]}) as dist, {self.config.column_map["embedding"]} as embedding diff --git a/libs/community/langchain_community/vectorstores/databricks_vector_search.py b/libs/community/langchain_community/vectorstores/databricks_vector_search.py deleted file mode 100644 index b779341a7..000000000 --- a/libs/community/langchain_community/vectorstores/databricks_vector_search.py +++ /dev/null @@ -1,693 +0,0 @@ -from __future__ import annotations - -import json -import logging -import uuid -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, -) - -import numpy as np -from langchain_core._api import deprecated, warn_deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VST, VectorStore - -from langchain_community.vectorstores.utils import maximal_marginal_relevance - -if TYPE_CHECKING: - from databricks.vector_search.client import VectorSearchIndex - -logger = logging.getLogger(__name__) - - -@deprecated( - since="0.3.3", - removal="1.0", - alternative_import="databricks_langchain.DatabricksVectorSearch", -) -class DatabricksVectorSearch(VectorStore): - """`Databricks Vector Search` vector store. - - To use, you should have the ``databricks-vectorsearch`` python package installed. - - Example: - .. 
code-block:: python - - from langchain_community.vectorstores import DatabricksVectorSearch - from databricks.vector_search.client import VectorSearchClient - - vs_client = VectorSearchClient() - vs_index = vs_client.get_index( - endpoint_name="vs_endpoint", - index_name="ml.llm.index" - ) - vectorstore = DatabricksVectorSearch(vs_index) - - Args: - index: A Databricks Vector Search index object. - embedding: The embedding model. - Required for direct-access index or delta-sync index - with self-managed embeddings. - text_column: The name of the text column to use for the embeddings. - Required for direct-access index or delta-sync index - with self-managed embeddings. - Make sure the text column specified is in the index. - columns: The list of column names to get when doing the search. - Defaults to ``[primary_key, text_column]``. - - Delta-sync index with Databricks-managed embeddings manages the ingestion, deletion, - and embedding for you. - Manually ingestion/deletion of the documents/texts is not supported for delta-sync - index. - - If you want to use a delta-sync index with self-managed embeddings, you need to - provide the embedding model and text column name to use for the embeddings. - - Example: - .. code-block:: python - - from langchain_community.vectorstores import DatabricksVectorSearch - from databricks.vector_search.client import VectorSearchClient - from langchain_community.embeddings.openai import OpenAIEmbeddings - - vs_client = VectorSearchClient() - vs_index = vs_client.get_index( - endpoint_name="vs_endpoint", - index_name="ml.llm.index" - ) - vectorstore = DatabricksVectorSearch( - index=vs_index, - embedding=OpenAIEmbeddings(), - text_column="document_content" - ) - - If you want to manage the documents ingestion/deletion yourself, you can use a - direct-access index. - - Example: - .. code-block:: python - - from langchain_community.vectorstores import DatabricksVectorSearch - from databricks.vector_search.client import VectorSearchClient - from langchain_community.embeddings.openai import OpenAIEmbeddings - - vs_client = VectorSearchClient() - vs_index = vs_client.get_index( - endpoint_name="vs_endpoint", - index_name="ml.llm.index" - ) - vectorstore = DatabricksVectorSearch( - index=vs_index, - embedding=OpenAIEmbeddings(), - text_column="document_content" - ) - vectorstore.add_texts( - texts=["text1", "text2"] - ) - - For more information on Databricks Vector Search, see `Databricks Vector Search - documentation: https://docs.databricks.com/en/generative-ai/vector-search.html. - - """ - - def __init__( - self, - index: VectorSearchIndex, - *, - embedding: Optional[Embeddings] = None, - text_column: Optional[str] = None, - columns: Optional[List[str]] = None, - ): - try: - from databricks.vector_search.client import VectorSearchIndex - except ImportError as e: - raise ImportError( - "Could not import databricks-vectorsearch python package. " - "Please install it with `pip install databricks-vectorsearch`." 
- ) from e - # index - self.index = index - if not isinstance(index, VectorSearchIndex): - raise TypeError("index must be of type VectorSearchIndex.") - - # index_details - index_details = self.index.describe() - self.primary_key = index_details["primary_key"] - self.index_type = index_details.get("index_type") - self._delta_sync_index_spec = index_details.get("delta_sync_index_spec", dict()) - self._direct_access_index_spec = index_details.get( - "direct_access_index_spec", dict() - ) - - # text_column - if self._is_databricks_managed_embeddings(): - index_source_column = self._embedding_source_column_name() - # check if input text column matches the source column of the index - if text_column is not None and text_column != index_source_column: - raise ValueError( - f"text_column '{text_column}' does not match with the " - f"source column of the index: '{index_source_column}'." - ) - self.text_column = index_source_column - else: - self._require_arg(text_column, "text_column") - self.text_column = text_column - - # columns - self.columns = columns or [] - # add primary key column and source column if not in columns - if self.primary_key not in self.columns: - self.columns.append(self.primary_key) - if self.text_column and self.text_column not in self.columns: - self.columns.append(self.text_column) - - # Validate specified columns are in the index - if self._is_direct_access_index(): - index_schema = self._index_schema() - if index_schema: - for col in self.columns: - if col not in index_schema: - raise ValueError( - f"column '{col}' is not in the index's schema." - ) - - # embedding model - if not self._is_databricks_managed_embeddings(): - # embedding model is required for direct-access index - # or delta-sync index with self-managed embedding - self._require_arg(embedding, "embedding") - self._embedding = embedding - # validate dimension matches - index_embedding_dimension = self._embedding_vector_column_dimension() - if index_embedding_dimension is not None: - inferred_embedding_dimension = self._infer_embedding_dimension() - if inferred_embedding_dimension != index_embedding_dimension: - raise ValueError( - f"embedding model's dimension '{inferred_embedding_dimension}' " - f"does not match with the index's dimension " - f"'{index_embedding_dimension}'." - ) - else: - if embedding is not None: - logger.warning( - "embedding model is not used in delta-sync index with " - "Databricks-managed embeddings." - ) - self._embedding = None - - @classmethod - def from_texts( - cls: Type[VST], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[Dict]] = None, - **kwargs: Any, - ) -> VST: - raise NotImplementedError( - "`from_texts` is not supported. " - "Use `add_texts` to add to existing direct-access index." - ) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[Dict]] = None, - ids: Optional[List[Any]] = None, - **kwargs: Any, - ) -> List[str]: - """Add texts to the index. - - Only support direct-access index. - - Args: - texts: List of texts to add. - metadatas: List of metadata for each text. Defaults to None. - ids: List of ids for each text. Defaults to None. - If not provided, a random uuid will be generated for each text. - - Returns: - List of ids from adding the texts into the index. - """ - self._op_require_direct_access_index("add_texts") - assert self.embeddings is not None, "embedding model is required." 
- # Wrap to list if input texts is a single string - if isinstance(texts, str): - texts = [texts] - texts = list(texts) - vectors = self.embeddings.embed_documents(texts) - ids = ids or [str(uuid.uuid4()) for _ in texts] - metadatas = metadatas or [{} for _ in texts] - - updates = [ - { - self.primary_key: id_, - self.text_column: text, - self._embedding_vector_column_name(): vector, - **metadata, - } - for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas) - ] - - upsert_resp = self.index.upsert(updates) - if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"): - failed_ids = upsert_resp.get("result", dict()).get( - "failed_primary_keys", [] - ) - if upsert_resp.get("status") == "FAILURE": - logger.error("Failed to add texts to the index.") - else: - logger.warning("Some texts failed to be added to the index.") - return [id_ for id_ in ids if id_ not in failed_ids] - - return ids - - @property - def embeddings(self) -> Optional[Embeddings]: - """Access the query embedding object if available.""" - return self._embedding - - def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]: - """Delete documents from the index. - - Only support direct-access index. - - Args: - ids: List of ids of documents to delete. - - Returns: - True if successful. - """ - self._op_require_direct_access_index("delete") - if ids is None: - raise ValueError("ids must be provided.") - self.index.delete(ids) - return True - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[Dict[str, Any]] = None, - *, - query_type: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filters to apply to the query. Defaults to None. - query_type: The type of this query. Supported values are "ANN" and "HYBRID". - - Returns: - List of Documents most similar to the embedding. - """ - docs_with_score = self.similarity_search_with_score( - query=query, - k=k, - filter=filter, - query_type=query_type, - **kwargs, - ) - return [doc for doc, _ in docs_with_score] - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[Dict[str, Any]] = None, - *, - query_type: Optional[str] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to query, along with scores. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filters to apply to the query. Defaults to None. - query_type: The type of this query. Supported values are "ANN" and "HYBRID". - - Returns: - List of Documents most similar to the embedding and score for each. - """ - if self._is_databricks_managed_embeddings(): - query_text = query - query_vector = None - else: - assert self.embeddings is not None, "embedding model is required." - # The value for `query_text` needs to be specified only for hybrid search. 
- if query_type is not None and query_type.upper() == "HYBRID": - query_text = query - else: - query_text = None - query_vector = self.embeddings.embed_query(query) - search_resp = self.index.similarity_search( - columns=self.columns, - query_text=query_text, - query_vector=query_vector, - filters=filter or _alias_filters(kwargs), - num_results=k, - query_type=query_type, - ) - return self._parse_search_response(search_resp) - - @staticmethod - def _identity_fn(score: float) -> float: - return score - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - Databricks Vector search uses a normalized score 1/(1+d) where d - is the L2 distance. Hence, we simply return the identity function. - """ - - return self._identity_fn - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, - *, - query_type: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filters to apply to the query. Defaults to None. - query_type: The type of this query. Supported values are "ANN" and "HYBRID". - Returns: - List of Documents selected by maximal marginal relevance. - """ - if not self._is_databricks_managed_embeddings(): - assert self.embeddings is not None, "embedding model is required." - query_vector = self.embeddings.embed_query(query) - else: - raise ValueError( - "`max_marginal_relevance_search` is not supported for index with " - "Databricks-managed embeddings." - ) - - docs = self.max_marginal_relevance_search_by_vector( - query_vector, - k, - fetch_k, - lambda_mult=lambda_mult, - filter=filter or _alias_filters(kwargs), - query_type=query_type, - ) - return docs - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Any] = None, - *, - query_type: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filters to apply to the query. Defaults to None. - query_type: The type of this query. Supported values are "ANN" and "HYBRID". - Returns: - List of Documents selected by maximal marginal relevance. - """ - if not self._is_databricks_managed_embeddings(): - embedding_column = self._embedding_vector_column_name() - else: - raise ValueError( - "`max_marginal_relevance_search` is not supported for index with " - "Databricks-managed embeddings." 
- ) - search_resp = self.index.similarity_search( - columns=list(set(self.columns + [embedding_column])), - query_text=None, - query_vector=embedding, - filters=filter or _alias_filters(kwargs), - num_results=fetch_k, - query_type=query_type, - ) - - embeddings_result_index = ( - search_resp.get("manifest").get("columns").index({"name": embedding_column}) - ) - embeddings = [ - doc[embeddings_result_index] - for doc in search_resp.get("result").get("data_array") - ] - - mmr_selected = maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - embeddings, - k=k, - lambda_mult=lambda_mult, - ) - - ignore_cols: List = ( - [embedding_column] if embedding_column not in self.columns else [] - ) - candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols) - selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected] - return selected_results - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Any] = None, - *, - query_type: Optional[str] = None, - query: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filters to apply to the query. Defaults to None. - query_type: The type of this query. Supported values are "ANN" and "HYBRID". - - Returns: - List of Documents most similar to the embedding. - """ - docs_with_score = self.similarity_search_by_vector_with_score( - embedding=embedding, - k=k, - filter=filter, - query_type=query_type, - query=query, - **kwargs, - ) - return [doc for doc, _ in docs_with_score] - - def similarity_search_by_vector_with_score( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Any] = None, - *, - query_type: Optional[str] = None, - query: Optional[str] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to embedding vector, along with scores. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filters to apply to the query. Defaults to None. - query_type: The type of this query. Supported values are "ANN" and "HYBRID". - - Returns: - List of Documents most similar to the embedding and score for each. - """ - if self._is_databricks_managed_embeddings(): - raise ValueError( - "`similarity_search_by_vector` is not supported for index with " - "Databricks-managed embeddings." - ) - if query_type is not None and query_type.upper() == "HYBRID": - if query is None: - raise ValueError( - "A value for `query` must be specified for hybrid search." 
- ) - query_text = query - else: - if query is not None: - raise ValueError( - ( - "Cannot specify both `embedding` and " - '`query` unless `query_type="HYBRID"' - ) - ) - query_text = None - search_resp = self.index.similarity_search( - columns=self.columns, - query_vector=embedding, - query_text=query_text, - filters=filter or _alias_filters(kwargs), - num_results=k, - query_type=query_type, - ) - return self._parse_search_response(search_resp) - - def _parse_search_response( - self, search_resp: Dict, ignore_cols: Optional[List[str]] = None - ) -> List[Tuple[Document, float]]: - """Parse the search response into a list of `Document` objects with score.""" - if ignore_cols is None: - ignore_cols = [] - - columns = [ - col["name"] - for col in search_resp.get("manifest", dict()).get("columns", []) - ] - docs_with_score = [] - for result in search_resp.get("result", dict()).get("data_array", []): - doc_id = result[columns.index(self.primary_key)] - text_content = result[columns.index(self.text_column)] - metadata = { - col: value - for col, value in zip(columns[:-1], result[:-1]) - if col not in ([self.primary_key, self.text_column] + ignore_cols) - } - metadata[self.primary_key] = doc_id - score = result[-1] - doc = Document(page_content=text_content, metadata=metadata) - docs_with_score.append((doc, score)) - return docs_with_score - - def _index_schema(self) -> Optional[Dict]: - """Return the index schema as a dictionary. - Return None if no schema found. - """ - if self._is_direct_access_index(): - schema_json = self._direct_access_index_spec.get("schema_json") - if schema_json is not None: - return json.loads(schema_json) - return None - - def _embedding_vector_column_name(self) -> Optional[str]: - """Return the name of the embedding vector column. - None if the index is not a self-managed embedding index. - """ - return self._embedding_vector_column().get("name") - - def _embedding_vector_column_dimension(self) -> Optional[int]: - """Return the dimension of the embedding vector column. - None if the index is not a self-managed embedding index. - """ - return self._embedding_vector_column().get("embedding_dimension") - - def _embedding_vector_column(self) -> Dict: - """Return the embedding vector column configs as a dictionary. - Empty if the index is not a self-managed embedding index. - """ - index_spec = ( - self._delta_sync_index_spec - if self._is_delta_sync_index() - else self._direct_access_index_spec - ) - return next(iter(index_spec.get("embedding_vector_columns") or list()), dict()) - - def _embedding_source_column_name(self) -> Optional[str]: - """Return the name of the embedding source column. - None if the index is not a Databricks-managed embedding index. - """ - return self._embedding_source_column().get("name") - - def _embedding_source_column(self) -> Dict: - """Return the embedding source column configs as a dictionary. - Empty if the index is not a Databricks-managed embedding index. 
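# A minimal migration sketch for the class removed above: its @deprecated decorator
# names `databricks_langchain.DatabricksVectorSearch` as the replacement. The
# constructor arguments used here (endpoint, index_name, embedding, text_column)
# are assumptions; check the databricks-langchain documentation for the exact
# signature before relying on them.
from databricks_langchain import DatabricksVectorSearch
from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works

vector_store = DatabricksVectorSearch(
    endpoint="vs_endpoint",        # assumed parameter name
    index_name="ml.llm.index",     # assumed parameter name
    embedding=OpenAIEmbeddings(),  # only needed for self-managed embeddings
    text_column="document_content",
)
docs = vector_store.similarity_search("what does the index contain?", k=4)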
- """ - index_spec = self._delta_sync_index_spec - return next(iter(index_spec.get("embedding_source_columns") or list()), dict()) - - def _is_delta_sync_index(self) -> bool: - """Return True if the index is a delta-sync index.""" - return self.index_type == "DELTA_SYNC" - - def _is_direct_access_index(self) -> bool: - """Return True if the index is a direct-access index.""" - return self.index_type == "DIRECT_ACCESS" - - def _is_databricks_managed_embeddings(self) -> bool: - """Return True if the embeddings are managed by Databricks Vector Search.""" - return ( - self._is_delta_sync_index() - and self._embedding_source_column_name() is not None - ) - - def _infer_embedding_dimension(self) -> int: - """Infer the embedding dimension from the embedding function.""" - assert self.embeddings is not None, "embedding model is required." - return len(self.embeddings.embed_query("test")) - - def _op_require_direct_access_index(self, op_name: str) -> None: - """ - Raise ValueError if the operation is not supported for direct-access index.""" - if not self._is_direct_access_index(): - raise ValueError(f"`{op_name}` is only supported for direct-access index.") - - @staticmethod - def _require_arg(arg: Any, arg_name: str) -> None: - """Raise ValueError if the required arg with name `arg_name` is None.""" - if not arg: - raise ValueError(f"`{arg_name}` is required for this index.") - - -def _alias_filters(kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """ - The `filters` argument was used in the previous versions. It is now - replaced with `filter` for consistency with other vector stores, but - we still support `filters` for backward compatibility. - """ - if "filters" in kwargs: - warn_deprecated( - since="0.2.11", - removal="1.0", - message="DatabricksVectorSearch received a key `filters` in search_kwargs. " - "`filters` was deprecated since langchain-community 0.2.11 and will " - "be removed in 0.3. Please use `filter` instead.", - ) - return kwargs.pop("filters", None) diff --git a/libs/community/langchain_community/vectorstores/deeplake.py b/libs/community/langchain_community/vectorstores/deeplake.py deleted file mode 100644 index 232f86a27..000000000 --- a/libs/community/langchain_community/vectorstores/deeplake.py +++ /dev/null @@ -1,970 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -import numpy as np - -try: - import deeplake - from deeplake import VectorStore as DeepLakeVectorStore - from deeplake.core.fast_forwarding import version_compare - from deeplake.util.exceptions import SampleExtendError - - _DEEPLAKE_INSTALLED = True -except ImportError: - _DEEPLAKE_INSTALLED = False - -from langchain_core._api import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore - -from langchain_community.vectorstores.utils import maximal_marginal_relevance - -logger = logging.getLogger(__name__) - - -@deprecated( - since="0.3.3", - removal="1.0", - message=( - "This class is deprecated and will be removed in a future version. " - "You can swap to using the `DeeplakeVectorStore`" - " implementation in `langchain-deeplake`. " - "Please do not submit further PRs to this class." - "See " - ), - alternative_import="langchain_deeplake.DeeplakeVectorStore", -) -class DeepLake(VectorStore): - """`Activeloop Deep Lake` vector store. 
- - We integrated deeplake's similarity search and filtering for fast prototyping. - Now, it supports Tensor Query Language (TQL) for production use cases - over billion rows. - - Why Deep Lake? - - - Not only stores embeddings, but also the original data with version control. - - Serverless, doesn't require another service and can be used with major - cloud providers (S3, GCS, etc.) - - More than just a multi-modal vector store. You can use the dataset - to fine-tune your own LLM models. - - To use, you should have the ``deeplake`` python package installed. - - Example: - .. code-block:: python - - from langchain_community.vectorstores import DeepLake - from langchain_community.embeddings.openai import OpenAIEmbeddings - - embeddings = OpenAIEmbeddings() - vectorstore = DeepLake("langchain_store", embeddings.embed_query) - """ - - _LANGCHAIN_DEFAULT_DEEPLAKE_PATH: str = "./deeplake/" - _valid_search_kwargs = ["lambda_mult"] - - def __init__( - self, - dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH, - token: Optional[str] = None, - embedding: Optional[Embeddings] = None, - embedding_function: Optional[Embeddings] = None, - read_only: bool = False, - ingestion_batch_size: int = 1024, - num_workers: int = 0, - verbose: bool = True, - exec_option: Optional[str] = None, - runtime: Optional[Dict] = None, - index_params: Optional[Dict[str, Union[int, str]]] = None, - **kwargs: Any, - ) -> None: - """Creates an empty DeepLakeVectorStore or loads an existing one. - - The DeepLakeVectorStore is located at the specified ``path``. - - Examples: - >>> # Create a vector store with default tensors - >>> deeplake_vectorstore = DeepLake( - ... path = , - ... ) - >>> - >>> # Create a vector store in the Deep Lake Managed Tensor Database - >>> data = DeepLake( - ... path = "hub://org_id/dataset_name", - ... runtime = {"tensor_db": True}, - ... ) - - Args: - dataset_path (str): The full path for storing to the Deep Lake - Vector Store. It can be: - - a Deep Lake cloud path of the form ``hub://org_id/dataset_name``. - Requires registration with Deep Lake. - - an s3 path of the form ``s3://bucketname/path/to/dataset``. - Credentials are required in either the environment or passed to - the creds argument. - - a local file system path of the form ``./path/to/dataset`` - or ``~/path/to/dataset`` or ``path/to/dataset``. - - a memory path of the form ``mem://path/to/dataset`` which doesn't - save the dataset but keeps it in memory instead. - Should be used only for testing as it does not persist. - Defaults to _LANGCHAIN_DEFAULT_DEEPLAKE_PATH. - token (str, optional): Activeloop token, for fetching credentials - to the dataset at path if it is a Deep Lake dataset. - Tokens are normally autogenerated. Optional. - embedding (Embeddings, optional): Function to convert - either documents or query. Optional. - embedding_function (Embeddings, optional): Function to convert - either documents or query. Optional. Deprecated: keeping this - parameter for backwards compatibility. - read_only (bool): Open dataset in read-only mode. Default is False. - ingestion_batch_size (int): During data ingestion, data is divided - into batches. Batch size is the size of each batch. - Default is 1024. - num_workers (int): Number of workers to use during data ingestion. - Default is 0. - verbose (bool): Print dataset summary after each operation. - Default is True. - exec_option (str, optional): Default method for search execution. - It could be either ``"auto"``, ``"python"``, ``"compute_engine"`` - or ``"tensor_db"``. 
Defaults to ``"auto"``. - If None, it's set to "auto". - - ``auto``- Selects the best execution method based on the storage - location of the Vector Store. It is the default option. - - ``python`` - Pure-python implementation that runs on the client and - can be used for data stored anywhere. WARNING: using this option - with big datasets is discouraged because it can lead to - memory issues. - - ``compute_engine`` - Performant C++ implementation of the Deep Lake - Compute Engine that runs on the client and can be used for any data - stored in or connected to Deep Lake. It cannot be used with - in-memory or local datasets. - - ``tensor_db`` - Performant and fully-hosted Managed Tensor Database - that is responsible for storage and query execution. Only available - for data stored in the Deep Lake Managed Database. Store datasets - in this database by specifying runtime = {"tensor_db": True} - during dataset creation. - runtime (Dict, optional): Parameters for creating the Vector Store in - Deep Lake's Managed Tensor Database. Not applicable when loading an - existing Vector Store. To create a Vector Store in the Managed Tensor - Database, set `runtime = {"tensor_db": True}`. - index_params (Optional[Dict[str, Union[int, str]]], optional): Dictionary - containing information about vector index that will be created. Defaults - to None, which will utilize ``DEFAULT_VECTORSTORE_INDEX_PARAMS`` from - ``deeplake.constants``. The specified key-values override the default - ones. - - threshold: The threshold for the dataset size above which an index - will be created for the embedding tensor. When the threshold value - is set to -1, index creation is turned off. Defaults to -1, which - turns off the index. - - distance_metric: This key specifies the method of calculating the - distance between vectors when creating the vector database (VDB) - index. It can either be a string that corresponds to a member of - the DistanceType enumeration, or the string value itself. - - If no value is provided, it defaults to "L2". - - "L2" corresponds to DistanceType.L2_NORM. - - "COS" corresponds to DistanceType.COSINE_SIMILARITY. - - additional_params: Additional parameters for fine-tuning the index. - **kwargs: Other optional keyword arguments. - - Raises: - ValueError: If some condition is not met. - """ - - self.ingestion_batch_size = ingestion_batch_size - self.num_workers = num_workers - self.verbose = verbose - - if _DEEPLAKE_INSTALLED is False: - raise ImportError( - "Could not import deeplake python package. " - "Please install it with `pip install deeplake[enterprise]<4.0.0`." - ) - - if ( - runtime == {"tensor_db": True} - and version_compare(deeplake.__version__, "3.6.7") == -1 - ): - raise ImportError( - "To use tensor_db option you need to update deeplake to `3.6.7` or " - "higher. " - f"Currently installed deeplake version is {deeplake.__version__}. " - ) - - self.dataset_path = dataset_path - - if embedding_function: - logger.warning( - "Using embedding function is deprecated and will be removed " - "in the future. Please use embedding instead." 
- ) - - self.vectorstore = DeepLakeVectorStore( - path=self.dataset_path, - embedding_function=embedding_function or embedding, - read_only=read_only, - token=token, - exec_option=exec_option, - verbose=verbose, - runtime=runtime, - index_params=index_params, - **kwargs, - ) - - self._embedding_function = embedding_function or embedding - self._id_tensor_name = "ids" if "ids" in self.vectorstore.tensors() else "id" - - @property - def embeddings(self) -> Optional[Embeddings]: - return self._embedding_function - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Examples: - >>> ids = deeplake_vectorstore.add_texts( - ... texts = , - ... metadatas = , - ... ids = , - ... ) - - Args: - texts (Iterable[str]): Texts to add to the vectorstore. - metadatas (Optional[List[dict]], optional): Optional list of metadatas. - ids (Optional[List[str]], optional): Optional list of IDs. - embedding_function (Optional[Embeddings], optional): Embedding function - to use to convert the text into embeddings. - **kwargs (Any): Any additional keyword arguments passed is not supported - by this method. - - Returns: - List[str]: List of IDs of the added texts. - """ - self._validate_kwargs(kwargs, "add_texts") - - kwargs = {} - if ids: - if self._id_tensor_name == "ids": # for backwards compatibility - kwargs["ids"] = ids - else: - kwargs["id"] = ids - - if metadatas is None: - metadatas = [{}] * len(list(texts)) - - if not isinstance(texts, list): - texts = list(texts) - - if texts is None: - raise ValueError("`texts` parameter shouldn't be None.") - elif len(texts) == 0: - raise ValueError("`texts` parameter shouldn't be empty.") - - try: - return self.vectorstore.add( - text=texts, - metadata=metadatas, - embedding_data=texts, - embedding_tensor="embedding", - embedding_function=self._embedding_function.embed_documents, # type: ignore[union-attr] - return_ids=True, - **kwargs, - ) - except SampleExtendError as e: - if "Failed to append a sample to the tensor 'metadata'" in str(e): - msg = ( - "**Hint: You might be using invalid type of argument in " - "document loader (e.g. 'pathlib.PosixPath' instead of 'str')" - ) - raise ValueError(e.args[0] + "\n\n" + msg) - else: - raise e - - def _search_tql( - self, - tql: Optional[str], - exec_option: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Function for performing tql_search. - - Args: - tql (str): TQL Query string for direct evaluation. - Available only for `compute_engine` and `tensor_db`. - exec_option (str, optional): Supports 3 ways to search. - Could be "python", "compute_engine" or "tensor_db". Default is "python". - - ``python`` - Pure-python implementation for the client. - WARNING: not recommended for big datasets due to potential memory - issues. - - ``compute_engine`` - C++ implementation of Deep Lake Compute - Engine for the client. Not for in-memory or local datasets. - - ``tensor_db`` - Hosted Managed Tensor Database for storage - and query execution. Only for data in Deep Lake Managed Database. - Use runtime = {"db_engine": True} during dataset creation. - return_score (bool): Return score with document. Default is False. - - Returns: - Tuple[List[Document], List[Tuple[Document, float]]] - A tuple of two lists. - The first list contains Documents, and the second list contains - tuples of Document and float score. 
- - Raises: - ValueError: If return_score is True but some condition is not met. - """ - result = self.vectorstore.search( - query=tql, - exec_option=exec_option, - ) - metadatas = result["metadata"] - texts = result["text"] - - docs = [ - Document( - page_content=text, - metadata=metadata, - ) - for text, metadata in zip(texts, metadatas) - ] - - if kwargs: - unsupported_argument = next(iter(kwargs)) - if kwargs[unsupported_argument] is not False: - raise ValueError( - f"specifying {unsupported_argument} is " - "not supported with tql search." - ) - - return docs - - def _search( - self, - query: Optional[str] = None, - embedding: Optional[Union[List[float], np.ndarray]] = None, - embedding_function: Optional[Callable] = None, - k: int = 4, - distance_metric: Optional[str] = None, - use_maximal_marginal_relevance: bool = False, - fetch_k: Optional[int] = 20, - filter: Optional[Union[Dict, Callable]] = None, - return_score: bool = False, - exec_option: Optional[str] = None, - deep_memory: bool = False, - **kwargs: Any, - ) -> Any[List[Document], List[Tuple[Document, float]]]: - """ - Return docs similar to query. - - Args: - query (str, optional): Text to look up similar docs. - embedding (Union[List[float], np.ndarray], optional): Query's embedding. - embedding_function (Callable, optional): Function to convert `query` - into embedding. - k (int): Number of Documents to return. - distance_metric (Optional[str], optional): `L2` for Euclidean, `L1` for - Nuclear, `max` for L-infinity distance, `cos` for cosine similarity, - 'dot' for dot product. - filter (Union[Dict, Callable], optional): Additional filter prior - to the embedding search. - - ``Dict`` - Key-value search on tensors of htype json, on an - AND basis (a sample must satisfy all key-value filters to be True) - Dict = {"tensor_name_1": {"key": value}, - "tensor_name_2": {"key": value}} - - ``Function`` - Any function compatible with `deeplake.filter`. - use_maximal_marginal_relevance (bool): Use maximal marginal relevance. - fetch_k (int): Number of Documents for MMR algorithm. - return_score (bool): Return the score. - exec_option (str, optional): Supports 3 ways to perform searching. - Could be "python", "compute_engine" or "tensor_db". - - ``python`` - Pure-python implementation for the client. - WARNING: not recommended for big datasets. - - ``compute_engine`` - C++ implementation of Deep Lake Compute - Engine for the client. Not for in-memory or local datasets. - - ``tensor_db`` - Hosted Managed Tensor Database for storage - and query execution. Only for data in Deep Lake Managed Database. - Use runtime = {"db_engine": True} during dataset creation. - deep_memory (bool): Whether to use the Deep Memory model for improving - search results. Defaults to False if deep_memory is not specified in - the Vector Store initialization. If True, the distance metric is set - to "deepmemory_distance", which represents the metric with which the - model was trained. The search is performed using the Deep Memory model. - If False, the distance metric is set to "COS" or whatever distance - metric user specifies. - kwargs: Additional keyword arguments. - - Returns: - List of Documents by the specified distance metric, - if return_score True, return a tuple of (Document, score) - - Raises: - ValueError: if both `embedding` and `embedding_function` are not specified. - """ - if kwargs.get("tql_query"): - logger.warning("`tql_query` is deprecated. 
Please use `tql` instead.") - kwargs["tql"] = kwargs.pop("tql_query") - - if kwargs.get("tql"): - return self._search_tql( - tql=kwargs["tql"], - exec_option=exec_option, - return_score=return_score, - embedding=embedding, - embedding_function=embedding_function, - distance_metric=distance_metric, - use_maximal_marginal_relevance=use_maximal_marginal_relevance, - filter=filter, - ) - - self._validate_kwargs(kwargs, "search") - - if embedding_function: - if isinstance(embedding_function, Embeddings): - _embedding_function = embedding_function.embed_query - else: - _embedding_function = embedding_function - elif self._embedding_function: - _embedding_function = self._embedding_function.embed_query - else: - _embedding_function = None - - if embedding is None: - if _embedding_function is None: - raise ValueError( - "Either `embedding` or `embedding_function` needs to be specified." - ) - - embedding = _embedding_function(query) if query else None - - if isinstance(embedding, list): - embedding = np.array(embedding, dtype=np.float32) - if len(embedding.shape) > 1: - embedding = embedding[0] - - result = self.vectorstore.search( - embedding=embedding, - k=fetch_k if use_maximal_marginal_relevance else k, - distance_metric=distance_metric, - filter=filter, - exec_option=exec_option, - return_tensors=["embedding", "metadata", "text", self._id_tensor_name], - deep_memory=deep_memory, - ) - scores = result["score"] - embeddings = result["embedding"] - metadatas = result["metadata"] - texts = result["text"] - - if use_maximal_marginal_relevance: - lambda_mult = kwargs.get("lambda_mult", 0.5) - indices = maximal_marginal_relevance( - embedding, # type: ignore[arg-type] - embeddings, - k=min(k, len(texts)), - lambda_mult=lambda_mult, - ) - - scores = [scores[i] for i in indices] - texts = [texts[i] for i in indices] - metadatas = [metadatas[i] for i in indices] - - docs = [ - Document( - page_content=text, - metadata=metadata, - ) - for text, metadata in zip(texts, metadatas) - ] - - if return_score: - if not isinstance(scores, list): - scores = [scores] - - return [(doc, score) for doc, score in zip(docs, scores)] - - return docs - - def similarity_search( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> List[Document]: - """ - Return docs most similar to query. - - Examples: - >>> # Search using an embedding - >>> data = vector_store.similarity_search( - ... query=, - ... k=, - ... exec_option=, - ... ) - >>> # Run tql search: - >>> data = vector_store.similarity_search( - ... query=None, - ... tql="SELECT * WHERE id == ", - ... exec_option="compute_engine", - ... ) - - Args: - k (int): Number of Documents to return. Defaults to 4. - query (str): Text to look up similar documents. - kwargs: Additional keyword arguments include: - embedding (Callable): Embedding function to use. Defaults to None. - distance_metric (str): 'L2' for Euclidean, 'L1' for Nuclear, 'max' - for L-infinity, 'cos' for cosine, 'dot' for dot product. - Defaults to 'L2'. - filter (Union[Dict, Callable], optional): Additional filter - before embedding search. - - Dict: Key-value search on tensors of htype json, - (sample must satisfy all key-value filters) - Dict = {"tensor_1": {"key": value}, "tensor_2": {"key": value}} - - Function: Compatible with `deeplake.filter`. - Defaults to None. - exec_option (str): Supports 3 ways to perform searching. - 'python', 'compute_engine', or 'tensor_db'. Defaults to 'python'. - - 'python': Pure-python implementation for the client. - WARNING: not recommended for big datasets. 
- - 'compute_engine': C++ implementation of the Compute Engine for - the client. Not for in-memory or local datasets. - - 'tensor_db': Managed Tensor Database for storage and query. - Only for data in Deep Lake Managed Database. - Use `runtime = {"db_engine": True}` during dataset creation. - deep_memory (bool): Whether to use the Deep Memory model for improving - search results. Defaults to False if deep_memory is not specified - in the Vector Store initialization. If True, the distance metric - is set to "deepmemory_distance", which represents the metric with - which the model was trained. The search is performed using the Deep - Memory model. If False, the distance metric is set to "COS" or - whatever distance metric user specifies. - - Returns: - List[Document]: List of Documents most similar to the query vector. - """ - - return self._search( - query=query, - k=k, - use_maximal_marginal_relevance=False, - return_score=False, - **kwargs, - ) - - def similarity_search_by_vector( - self, - embedding: Union[List[float], np.ndarray], - k: int = 4, - **kwargs: Any, - ) -> List[Document]: - """ - Return docs most similar to embedding vector. - - Examples: - >>> # Search using an embedding - >>> data = vector_store.similarity_search_by_vector( - ... embedding=, - ... k=, - ... exec_option=, - ... ) - - Args: - embedding (Union[List[float], np.ndarray]): - Embedding to find similar docs. - k (int): Number of Documents to return. Defaults to 4. - kwargs: Additional keyword arguments including: - filter (Union[Dict, Callable], optional): - Additional filter before embedding search. - - ``Dict`` - Key-value search on tensors of htype json. True - if all key-value filters are satisfied. - Dict = {"tensor_name_1": {"key": value}, - "tensor_name_2": {"key": value}} - - ``Function`` - Any function compatible with - `deeplake.filter`. - Defaults to None. - exec_option (str): Options for search execution include - "python", "compute_engine", or "tensor_db". Defaults to - "python". - - "python" - Pure-python implementation running on the client. - Can be used for data stored anywhere. WARNING: using this - option with big datasets is discouraged due to potential - memory issues. - - "compute_engine" - Performant C++ implementation of the Deep - Lake Compute Engine. Runs on the client and can be used for - any data stored in or connected to Deep Lake. It cannot be - used with in-memory or local datasets. - - "tensor_db" - Performant, fully-hosted Managed Tensor Database. - Responsible for storage and query execution. Only available - for data stored in the Deep Lake Managed Database. - To store datasets in this database, specify - `runtime = {"db_engine": True}` during dataset creation. - distance_metric (str): `L2` for Euclidean, `L1` for Nuclear, - `max` for L-infinity distance, `cos` for cosine similarity, - 'dot' for dot product. Defaults to `L2`. - deep_memory (bool): Whether to use the Deep Memory model for improving - search results. Defaults to False if deep_memory is not specified - in the Vector Store initialization. If True, the distance metric - is set to "deepmemory_distance", which represents the metric with - which the model was trained. The search is performed using the Deep - Memory model. If False, the distance metric is set to "COS" or - whatever distance metric user specifies. - - Returns: - List[Document]: List of Documents most similar to the query vector. 
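# A short illustration of the MMR re-ranking step used by the deleted stores above:
# fetch `fetch_k` candidates by plain similarity, then let
# `maximal_marginal_relevance` (from langchain_community.vectorstores.utils) pick
# `k` indices that trade relevance against diversity via `lambda_mult`
# (1 = minimum diversity, 0 = maximum diversity). The embeddings below are made up.
import numpy as np
from langchain_community.vectorstores.utils import maximal_marginal_relevance

query_embedding = np.array([0.1, 0.9, 0.0], dtype=np.float32)
candidate_embeddings = [
    [0.1, 0.8, 0.1],   # close to the query
    [0.1, 0.8, 0.05],  # near-duplicate of the first candidate
    [0.9, 0.1, 0.2],   # dissimilar; adds diversity
]
selected = maximal_marginal_relevance(
    query_embedding, candidate_embeddings, lambda_mult=0.5, k=2
)
# `selected` contains indices into `candidate_embeddings`, e.g. [0, 2].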
- """ - - return self._search( - embedding=embedding, - k=k, - use_maximal_marginal_relevance=False, - return_score=False, - **kwargs, - ) - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """ - Run similarity search with Deep Lake with distance returned. - - Examples: - >>> data = vector_store.similarity_search_with_score( - ... query=, - ... embedding= - ... k=, - ... exec_option=, - ... ) - - Args: - query (str): Query text to search for. - k (int): Number of results to return. Defaults to 4. - kwargs: Additional keyword arguments. Some of these arguments are: - distance_metric: `L2` for Euclidean, `L1` for Nuclear, `max` L-infinity - distance, `cos` for cosine similarity, 'dot' for dot product. - Defaults to `L2`. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - embedding_function (Callable): Embedding function to use. Defaults - to None. - exec_option (str): DeepLakeVectorStore supports 3 ways to perform - searching. It could be either "python", "compute_engine" or - "tensor_db". Defaults to "python". - - "python" - Pure-python implementation running on the client. - Can be used for data stored anywhere. WARNING: using this - option with big datasets is discouraged due to potential - memory issues. - - "compute_engine" - Performant C++ implementation of the Deep - Lake Compute Engine. Runs on the client and can be used for - any data stored in or connected to Deep Lake. It cannot be used - with in-memory or local datasets. - - "tensor_db" - Performant, fully-hosted Managed Tensor Database. - Responsible for storage and query execution. Only available for - data stored in the Deep Lake Managed Database. To store datasets - in this database, specify `runtime = {"db_engine": True}` - during dataset creation. - deep_memory (bool): Whether to use the Deep Memory model for improving - search results. Defaults to False if deep_memory is not specified - in the Vector Store initialization. If True, the distance metric - is set to "deepmemory_distance", which represents the metric with - which the model was trained. The search is performed using the Deep - Memory model. If False, the distance metric is set to "COS" or - whatever distance metric user specifies. - - Returns: - List[Tuple[Document, float]]: List of documents most similar to the query - text with distance in float.""" - - return self._search( - query=query, - k=k, - return_score=True, - **kwargs, - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - exec_option: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """ - Return docs selected using the maximal marginal relevance. Maximal marginal - relevance optimizes for similarity to query AND diversity among selected docs. - - Examples: - >>> data = vector_store.max_marginal_relevance_search_by_vector( - ... embedding=, - ... fetch_k=, - ... k=, - ... exec_option=, - ... ) - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch for MMR algorithm. - lambda_mult: Number between 0 and 1 determining the degree of diversity. - 0 corresponds to max diversity and 1 to min diversity. Defaults to 0.5. - exec_option (str): DeepLakeVectorStore supports 3 ways for searching. - Could be "python", "compute_engine" or "tensor_db". Defaults to - "python". 
- - "python" - Pure-python implementation running on the client. - Can be used for data stored anywhere. WARNING: using this - option with big datasets is discouraged due to potential - memory issues. - - "compute_engine" - Performant C++ implementation of the Deep - Lake Compute Engine. Runs on the client and can be used for - any data stored in or connected to Deep Lake. It cannot be used - with in-memory or local datasets. - - "tensor_db" - Performant, fully-hosted Managed Tensor Database. - Responsible for storage and query execution. Only available for - data stored in the Deep Lake Managed Database. To store datasets - in this database, specify `runtime = {"db_engine": True}` - during dataset creation. - deep_memory (bool): Whether to use the Deep Memory model for improving - search results. Defaults to False if deep_memory is not specified - in the Vector Store initialization. If True, the distance metric - is set to "deepmemory_distance", which represents the metric with - which the model was trained. The search is performed using the Deep - Memory model. If False, the distance metric is set to "COS" or - whatever distance metric user specifies. - kwargs: Additional keyword arguments. - - Returns: - List[Documents] - A list of documents. - """ - - return self._search( - embedding=embedding, - k=k, - fetch_k=fetch_k, - use_maximal_marginal_relevance=True, - lambda_mult=lambda_mult, - exec_option=exec_option, - **kwargs, - ) - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - exec_option: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Examples: - >>> # Search using an embedding - >>> data = vector_store.max_marginal_relevance_search( - ... query = , - ... embedding_function = , - ... k = , - ... exec_option = , - ... ) - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents for MMR algorithm. - lambda_mult: Value between 0 and 1. 0 corresponds - to maximum diversity and 1 to minimum. - Defaults to 0.5. - exec_option (str): Supports 3 ways to perform searching. - - "python" - Pure-python implementation running on the client. - Can be used for data stored anywhere. WARNING: using this - option with big datasets is discouraged due to potential - memory issues. - - "compute_engine" - Performant C++ implementation of the Deep - Lake Compute Engine. Runs on the client and can be used for - any data stored in or connected to Deep Lake. It cannot be - used with in-memory or local datasets. - - "tensor_db" - Performant, fully-hosted Managed Tensor Database. - Responsible for storage and query execution. Only available - for data stored in the Deep Lake Managed Database. To store - datasets in this database, specify - `runtime = {"db_engine": True}` during dataset creation. - deep_memory (bool): Whether to use the Deep Memory model for improving - search results. Defaults to False if deep_memory is not specified - in the Vector Store initialization. If True, the distance metric - is set to "deepmemory_distance", which represents the metric with - which the model was trained. The search is performed using the Deep - Memory model. If False, the distance metric is set to "COS" or - whatever distance metric user specifies. 
- kwargs: Additional keyword arguments - - Returns: - List of Documents selected by maximal marginal relevance. - - Raises: - ValueError: when MRR search is on but embedding function is - not specified. - """ - embedding_function = kwargs.get("embedding") or self._embedding_function - if embedding_function is None: - raise ValueError( - "For MMR search, you must specify an embedding function on" - " `creation` or during add call." - ) - return self._search( - query=query, - k=k, - fetch_k=fetch_k, - use_maximal_marginal_relevance=True, - lambda_mult=lambda_mult, - exec_option=exec_option, - embedding_function=embedding_function, # type: ignore[arg-type] - **kwargs, - ) - - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Optional[Embeddings] = None, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH, - **kwargs: Any, - ) -> DeepLake: - """Create a Deep Lake dataset from a raw documents. - - If a dataset_path is specified, the dataset will be persisted in that location, - otherwise by default at `./deeplake` - - Examples: - >>> # Search using an embedding - >>> vector_store = DeepLake.from_texts( - ... texts = , - ... embedding_function = , - ... k = , - ... exec_option = , - ... ) - - Args: - dataset_path (str): - The full path to the dataset. Can be: - - Deep Lake cloud path of the form ``hub://username/dataset_name``. - To write to Deep Lake cloud datasets, - ensure that you are logged in to Deep Lake - (use 'activeloop login' from command line) - - AWS S3 path of the form ``s3://bucketname/path/to/dataset``. - Credentials are required in either the environment - - Google Cloud Storage path of the form - ``gcs://bucketname/path/to/dataset`` Credentials are required - in either the environment - - Local file system path of the form ``./path/to/dataset`` or - ``~/path/to/dataset`` or ``path/to/dataset``. - - In-memory path of the form ``mem://path/to/dataset`` which doesn't - save the dataset, but keeps it in memory instead. - Should be used only for testing as it does not persist. - texts (List[Document]): List of documents to add. - embedding (Optional[Embeddings]): Embedding function. Defaults to None. - Note, in other places, it is called embedding_function. - metadatas (Optional[List[dict]]): List of metadatas. Defaults to None. - ids (Optional[List[str]]): List of document IDs. Defaults to None. - kwargs: Additional keyword arguments. - - Returns: - DeepLake: Deep Lake dataset. - """ - deeplake_dataset = cls(dataset_path=dataset_path, embedding=embedding, **kwargs) - deeplake_dataset.add_texts( - texts=texts, - metadatas=metadatas, - ids=ids, - ) - return deeplake_dataset - - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool: - """Delete the entities in the dataset. - - Args: - ids (Optional[List[str]], optional): The document_ids to delete. - Defaults to None. - **kwargs: Other keyword arguments that subclasses might use. - - filter (Optional[Dict[str, str]], optional): The filter to delete by. - - delete_all (Optional[bool], optional): Whether to drop the dataset. - - Returns: - bool: Whether the delete operation was successful. - """ - filter = kwargs.get("filter") - delete_all = kwargs.get("delete_all") - - self.vectorstore.delete(ids=ids, filter=filter, delete_all=delete_all) - - return True - - @classmethod - def force_delete_by_path(cls, path: str) -> None: - """Force delete dataset by path. - - Args: - path (str): path of the dataset to delete. 
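# A minimal migration sketch for the class removed above: its @deprecated decorator
# names `langchain_deeplake.DeeplakeVectorStore` as the replacement. The constructor
# arguments shown (dataset_path, embedding_function) are assumptions; consult the
# langchain-deeplake documentation for the exact signature.
from langchain_deeplake import DeeplakeVectorStore
from langchain_openai import OpenAIEmbeddings

store = DeeplakeVectorStore(
    dataset_path="./deeplake/",             # assumed parameter name; mirrors the old default path
    embedding_function=OpenAIEmbeddings(),  # assumed parameter name
)
store.add_texts(["Deep Lake stores embeddings alongside the original data."])
docs = store.similarity_search("embeddings", k=1)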
- - Raises: - ValueError: if deeplake is not installed. - """ - - try: - import deeplake - except ImportError: - raise ImportError( - "Could not import deeplake python package. " - "Please install it with `pip install deeplake`." - ) - deeplake.delete(path, large_ok=True, force=True) - - def delete_dataset(self) -> None: - """Delete the collection.""" - self.delete(delete_all=True) - - def ds(self) -> Any: - logger.warning( - "this method is deprecated and will be removed, " - "better to use `db.vectorstore.dataset` instead." - ) - return self.vectorstore.dataset - - @classmethod - def _validate_kwargs(cls, kwargs: Any, method_name: str) -> None: - if kwargs: - valid_items = cls._get_valid_args(method_name) - unsupported_items = cls._get_unsupported_items(kwargs, valid_items) - - if unsupported_items: - raise TypeError( - f"`{unsupported_items}` are not a valid " - f"argument to {method_name} method" - ) - - @classmethod - def _get_valid_args(cls, method_name: str) -> list[str]: - if method_name == "search": - return cls._valid_search_kwargs - else: - return [] - - @staticmethod - def _get_unsupported_items(kwargs: Any, valid_items: list[str]) -> Optional[str]: - kwargs = {k: v for k, v in kwargs.items() if k not in valid_items} - unsupported_items = None - if kwargs: - unsupported_items = "`, `".join(set(kwargs.keys())) - return unsupported_items diff --git a/libs/community/langchain_community/vectorstores/hanavector.py b/libs/community/langchain_community/vectorstores/hanavector.py deleted file mode 100644 index a15a31fff..000000000 --- a/libs/community/langchain_community/vectorstores/hanavector.py +++ /dev/null @@ -1,842 +0,0 @@ -"""SAP HANA Cloud Vector Engine""" - -from __future__ import annotations - -import importlib.util -import json -import re -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Pattern, - Tuple, -) - -import numpy as np -from langchain_core._api import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.runnables.config import run_in_executor -from langchain_core.vectorstores import VectorStore -from typing_extensions import Self - -from langchain_community.vectorstores.utils import ( - DistanceStrategy, - maximal_marginal_relevance, -) - -if TYPE_CHECKING: - from hdbcli import dbapi - -HANA_DISTANCE_FUNCTION: dict = { - DistanceStrategy.COSINE: ("COSINE_SIMILARITY", "DESC"), - DistanceStrategy.EUCLIDEAN_DISTANCE: ("L2DISTANCE", "ASC"), -} - -COMPARISONS_TO_SQL = { - "$eq": "=", - "$ne": "<>", - "$lt": "<", - "$lte": "<=", - "$gt": ">", - "$gte": ">=", -} - -IN_OPERATORS_TO_SQL = { - "$in": "IN", - "$nin": "NOT IN", -} - -BETWEEN_OPERATOR = "$between" - -LIKE_OPERATOR = "$like" - -LOGICAL_OPERATORS_TO_SQL = {"$and": "AND", "$or": "OR"} - - -default_distance_strategy = DistanceStrategy.COSINE -default_table_name: str = "EMBEDDINGS" -default_content_column: str = "VEC_TEXT" -default_metadata_column: str = "VEC_META" -default_vector_column: str = "VEC_VECTOR" -default_vector_column_length: int = -1 # -1 means dynamic length - - -@deprecated( - since="0.3.23", - removal="1.0", - message=( - "This class is deprecated and will be removed in a future version. " - "Please use HanaDB from the langchain_hana package instead. " - "See https://github.com/SAP/langchain-integration-for-sap-hana-cloud " - "for details." 
- ), - alternative="from langchain_hana import HanaDB;", - pending=False, -) -class HanaDB(VectorStore): - """SAP HANA Cloud Vector Engine - - **DEPRECATED**: This class is deprecated and will no longer be maintained. - Please use HanaDB from the langchain_hana package instead. It offers an - improved implementation and full support. - - The prerequisite for using this class is the installation of the ``hdbcli`` - Python package. - - The HanaDB vectorstore can be created by providing an embedding function and - an existing database connection. Optionally, the names of the table and the - columns to use. - """ - - def __init__( - self, - connection: dbapi.Connection, - embedding: Embeddings, - distance_strategy: DistanceStrategy = default_distance_strategy, - table_name: str = default_table_name, - content_column: str = default_content_column, - metadata_column: str = default_metadata_column, - vector_column: str = default_vector_column, - vector_column_length: int = default_vector_column_length, - *, - specific_metadata_columns: Optional[List[str]] = None, - ): - # Check if the hdbcli package is installed - if importlib.util.find_spec("hdbcli") is None: - raise ImportError( - "Could not import hdbcli python package. " - "Please install it with `pip install hdbcli`." - ) - - valid_distance = False - for key in HANA_DISTANCE_FUNCTION.keys(): - if key is distance_strategy: - valid_distance = True - if not valid_distance: - raise ValueError( - "Unsupported distance_strategy: {}".format(distance_strategy) - ) - - self.connection = connection - self.embedding = embedding - self.distance_strategy = distance_strategy - self.table_name = HanaDB._sanitize_name(table_name) - self.content_column = HanaDB._sanitize_name(content_column) - self.metadata_column = HanaDB._sanitize_name(metadata_column) - self.vector_column = HanaDB._sanitize_name(vector_column) - self.vector_column_length = HanaDB._sanitize_int(vector_column_length) - self.specific_metadata_columns = HanaDB._sanitize_specific_metadata_columns( - specific_metadata_columns or [] - ) - - # Check if the table exists, and eventually create it - if not self._table_exists(self.table_name): - sql_str = ( - f'CREATE TABLE "{self.table_name}"(' - f'"{self.content_column}" NCLOB, ' - f'"{self.metadata_column}" NCLOB, ' - f'"{self.vector_column}" REAL_VECTOR ' - ) - if self.vector_column_length in [-1, 0]: - sql_str += ");" - else: - sql_str += f"({self.vector_column_length}));" - - try: - cur = self.connection.cursor() - cur.execute(sql_str) - finally: - cur.close() - - # Check if the needed columns exist and have the correct type - self._check_column(self.table_name, self.content_column, ["NCLOB", "NVARCHAR"]) - self._check_column(self.table_name, self.metadata_column, ["NCLOB", "NVARCHAR"]) - self._check_column( - self.table_name, - self.vector_column, - ["REAL_VECTOR"], - self.vector_column_length, - ) - for column_name in self.specific_metadata_columns: - self._check_column(self.table_name, column_name) - - def _table_exists(self, table_name: str) -> bool: - sql_str = ( - "SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA" - " AND TABLE_NAME = ?" 
- ) - try: - cur = self.connection.cursor() - cur.execute(sql_str, (table_name)) - if cur.has_result_set(): - rows = cur.fetchall() - if rows[0][0] == 1: - return True - finally: - cur.close() - return False - - def _check_column( - self, - table_name: str, - column_name: str, - column_type: Optional[list[str]] = None, - column_length: Optional[int] = None, - ) -> None: - sql_str = ( - "SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE " - "SCHEMA_NAME = CURRENT_SCHEMA " - "AND TABLE_NAME = ? AND COLUMN_NAME = ?" - ) - try: - cur = self.connection.cursor() - cur.execute(sql_str, (table_name, column_name)) - if cur.has_result_set(): - rows = cur.fetchall() - if len(rows) == 0: - raise AttributeError(f"Column {column_name} does not exist") - # Check data type - if column_type: - if rows[0][0] not in column_type: - raise AttributeError( - f"Column {column_name} has the wrong type: {rows[0][0]}" - ) - # Check length, if parameter was provided - # Length can either be -1 (QRC01+02-24) or 0 (QRC03-24 onwards) - # to indicate no length constraint being present. - if column_length is not None and column_length > 0: - if rows[0][1] != column_length: - raise AttributeError( - f"Column {column_name} has the wrong length: {rows[0][1]} " - f"expected: {column_length}" - ) - else: - raise AttributeError(f"Column {column_name} does not exist") - finally: - cur.close() - - @property - def embeddings(self) -> Embeddings: - return self.embedding - - @staticmethod - def _sanitize_name(input_str: str) -> str: - # Remove characters that are not alphanumeric or underscores - return re.sub(r"[^a-zA-Z0-9_]", "", input_str) - - @staticmethod - def _sanitize_int(input_int: any) -> int: # type: ignore[valid-type] - value = int(str(input_int)) - if value < -1: - raise ValueError(f"Value ({value}) must not be smaller than -1") - return int(str(input_int)) - - @staticmethod - def _sanitize_list_float(embedding: List[float]) -> List[float]: - for value in embedding: - if not isinstance(value, float): - raise ValueError(f"Value ({value}) does not have type float") - return embedding - - # Compile pattern only once, for better performance - _compiled_pattern: Pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$") - - @staticmethod - def _sanitize_metadata_keys(metadata: dict) -> dict: - for key in metadata.keys(): - if not HanaDB._compiled_pattern.match(key): - raise ValueError(f"Invalid metadata key {key}") - - return metadata - - @staticmethod - def _sanitize_specific_metadata_columns( - specific_metadata_columns: List[str], - ) -> List[str]: - metadata_columns = [] - for c in specific_metadata_columns: - sanitized_name = HanaDB._sanitize_name(c) - metadata_columns.append(sanitized_name) - return metadata_columns - - def _split_off_special_metadata(self, metadata: dict) -> Tuple[dict, list]: - # Use provided values by default or fallback - special_metadata = [] - - if not metadata: - return {}, [] - - for column_name in self.specific_metadata_columns: - special_metadata.append(metadata.get(column_name, None)) - - return metadata, special_metadata - - def create_hnsw_index( - self, - m: Optional[int] = None, # Optional M parameter - ef_construction: Optional[int] = None, # Optional efConstruction parameter - ef_search: Optional[int] = None, # Optional efSearch parameter - index_name: Optional[str] = None, # Optional custom index name - ) -> None: - """ - Creates an HNSW vector index on a specified table and vector column with - optional build and search configurations. 
If no configurations are provided, - default parameters from the database are used. If provided values exceed the - valid ranges, an error will be raised. - The index is always created in ONLINE mode. - - Args: - m: (Optional) Maximum number of neighbors per graph node - (Valid Range: [4, 1000]) - ef_construction: (Optional) Maximal candidates to consider when building - the graph (Valid Range: [1, 100000]) - ef_search: (Optional) Minimum candidates for top-k-nearest neighbor - queries (Valid Range: [1, 100000]) - index_name: (Optional) Custom index name. Defaults to - __idx - """ - # Set default index name if not provided - distance_func_name = HANA_DISTANCE_FUNCTION[self.distance_strategy][0] - default_index_name = f"{self.table_name}_{distance_func_name}_idx" - # Use provided index_name or default - index_name = ( - HanaDB._sanitize_name(index_name) if index_name else default_index_name - ) - # Initialize build_config and search_config as empty dictionaries - build_config = {} - search_config = {} - - # Validate and add m parameter to build_config if provided - if m is not None: - m = HanaDB._sanitize_int(m) - if not (4 <= m <= 1000): - raise ValueError("M must be in the range [4, 1000]") - build_config["M"] = m - - # Validate and add ef_construction to build_config if provided - if ef_construction is not None: - ef_construction = HanaDB._sanitize_int(ef_construction) - if not (1 <= ef_construction <= 100000): - raise ValueError("efConstruction must be in the range [1, 100000]") - build_config["efConstruction"] = ef_construction - - # Validate and add ef_search to search_config if provided - if ef_search is not None: - ef_search = HanaDB._sanitize_int(ef_search) - if not (1 <= ef_search <= 100000): - raise ValueError("efSearch must be in the range [1, 100000]") - search_config["efSearch"] = ef_search - - # Convert build_config and search_config to JSON strings if they contain values - build_config_str = json.dumps(build_config) if build_config else "" - search_config_str = json.dumps(search_config) if search_config else "" - - # Create the index SQL string with the ONLINE keyword - sql_str = ( - f'CREATE HNSW VECTOR INDEX {index_name} ON "{self.table_name}" ' - f'("{self.vector_column}") ' - f"SIMILARITY FUNCTION {distance_func_name} " - ) - - # Append build_config to the SQL string if provided - if build_config_str: - sql_str += f"BUILD CONFIGURATION '{build_config_str}' " - - # Append search_config to the SQL string if provided - if search_config_str: - sql_str += f"SEARCH CONFIGURATION '{search_config_str}' " - - # Always add the ONLINE option - sql_str += "ONLINE " - cur = self.connection.cursor() - try: - cur.execute(sql_str) - finally: - cur.close() - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - embeddings: Optional[List[List[float]]] = None, - **kwargs: Any, - ) -> List[str]: - """Add more texts to the vectorstore. - - Args: - texts (Iterable[str]): Iterable of strings/text to add to the vectorstore. - metadatas (Optional[List[dict]], optional): Optional list of metadatas. - Defaults to None. - embeddings (Optional[List[List[float]]], optional): Optional pre-generated - embeddings. Defaults to None. 
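# [Editorial sketch, not part of the diff] A minimal, hedged example of how the
# deprecated HanaDB class above was typically constructed and how the
# create_hnsw_index() helper defined above was called. Host, port and credentials
# are placeholders, and DeterministicFakeEmbedding merely stands in for any
# Embeddings implementation; langchain_hana.HanaDB is the documented replacement.
from hdbcli import dbapi
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_community.vectorstores import HanaDB  # pre-removal import path

connection = dbapi.connect(
    address="<hana-host>", port=443, user="<user>", password="<password>", autocommit=True
)
emb = DeterministicFakeEmbedding(size=384)  # stand-in for a real embedding model
db = HanaDB(connection=connection, embedding=emb, table_name="EMBEDDINGS")

# Optional ANN index; the values are arbitrary but lie inside the documented ranges.
db.create_hnsw_index(m=64, ef_construction=128, ef_search=200)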
- - Returns: - List[str]: empty list - """ - # Create all embeddings of the texts beforehand to improve performance - if embeddings is None: - embeddings = self.embedding.embed_documents(list(texts)) - - # Create sql parameters array - sql_params = [] - for i, text in enumerate(texts): - metadata = metadatas[i] if metadatas else {} - metadata, extracted_special_metadata = self._split_off_special_metadata( - metadata - ) - embedding = ( - embeddings[i] - if embeddings - else self.embedding.embed_documents([text])[0] - ) - sql_params.append( - ( - text, - json.dumps(HanaDB._sanitize_metadata_keys(metadata)), - f"[{','.join(map(str, embedding))}]", - *extracted_special_metadata, - ) - ) - - # Insert data into the table - cur = self.connection.cursor() - try: - specific_metadata_columns_string = '", "'.join( - self.specific_metadata_columns - ) - if specific_metadata_columns_string: - specific_metadata_columns_string = ( - ', "' + specific_metadata_columns_string + '"' - ) - sql_str = ( - f'INSERT INTO "{self.table_name}" ("{self.content_column}", ' - f'"{self.metadata_column}", ' - f'"{self.vector_column}"{specific_metadata_columns_string}) ' - f"VALUES (?, ?, TO_REAL_VECTOR (?)" - f"{', ?' * len(self.specific_metadata_columns)});" - ) - cur.executemany(sql_str, sql_params) - finally: - cur.close() - return [] - - @classmethod - def from_texts( # type: ignore[override] - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - connection: dbapi.Connection = None, - distance_strategy: DistanceStrategy = default_distance_strategy, - table_name: str = default_table_name, - content_column: str = default_content_column, - metadata_column: str = default_metadata_column, - vector_column: str = default_vector_column, - vector_column_length: int = default_vector_column_length, - *, - specific_metadata_columns: Optional[List[str]] = None, - ) -> Self: - """Create a HanaDB instance from raw documents. - This is a user-friendly interface that: - 1. Embeds documents. - 2. Creates a table if it does not yet exist. - 3. Adds the documents to the table. - This is intended to be a quick way to get started. - """ - - instance = cls( - connection=connection, - embedding=embedding, - distance_strategy=distance_strategy, - table_name=table_name, - content_column=content_column, - metadata_column=metadata_column, - vector_column=vector_column, - vector_column_length=vector_column_length, # -1 means dynamic length - specific_metadata_columns=specific_metadata_columns, - ) - instance.add_texts(texts, metadatas) - return instance - - def similarity_search( # type: ignore[override] - self, query: str, k: int = 4, filter: Optional[dict] = None - ) -> List[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: A dictionary of metadata fields and values to filter by. - Defaults to None. - - Returns: - List of Documents most similar to the query - """ - docs_and_scores = self.similarity_search_with_score( - query=query, k=k, filter=filter - ) - return [doc for doc, _ in docs_and_scores] - - def similarity_search_with_score( - self, query: str, k: int = 4, filter: Optional[dict] = None - ) -> List[Tuple[Document, float]]: - """Return documents and score values most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: A dictionary of metadata fields and values to filter by. - Defaults to None. 
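# [Editorial sketch, not part of the diff] The from_texts() factory above wraps
# construction plus add_texts() in one call; `connection` and `emb` are the
# placeholders introduced in the previous sketch, and the texts/metadata are
# arbitrary illustrations.
db = HanaDB.from_texts(
    texts=["Berlin is in Germany", "Paris is in France"],
    embedding=emb,
    metadatas=[{"country": "DE"}, {"country": "FR"}],
    connection=connection,
)
docs = db.similarity_search("Which city is in Germany?", k=1)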
- - Returns: - List of tuples (containing a Document and a score) that are - most similar to the query - """ - embedding = self.embedding.embed_query(query) - return self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) - - def similarity_search_with_score_and_vector_by_vector( - self, embedding: List[float], k: int = 4, filter: Optional[dict] = None - ) -> List[Tuple[Document, float, List[float]]]: - """Return docs most similar to the given embedding. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: A dictionary of metadata fields and values to filter by. - Defaults to None. - - Returns: - List of Documents most similar to the query and - score and the document's embedding vector for each - """ - result = [] - k = HanaDB._sanitize_int(k) - embedding = HanaDB._sanitize_list_float(embedding) - distance_func_name = HANA_DISTANCE_FUNCTION[self.distance_strategy][0] - embedding_as_str = "[" + ",".join(map(str, embedding)) + "]" - sql_str = ( - f"SELECT TOP {k}" - f' "{self.content_column}", ' # row[0] - f' "{self.metadata_column}", ' # row[1] - f' TO_NVARCHAR("{self.vector_column}"), ' # row[2] - f' {distance_func_name}("{self.vector_column}", TO_REAL_VECTOR (?)) AS CS ' - f'FROM "{self.table_name}"' - ) - order_str = f" order by CS {HANA_DISTANCE_FUNCTION[self.distance_strategy][1]}" - where_str, query_tuple = self._create_where_by_filter(filter) - query_params = (embedding_as_str,) + tuple(query_tuple) - sql_str = sql_str + where_str - sql_str = sql_str + order_str - try: - cur = self.connection.cursor() - cur.execute(sql_str, query_params) - if cur.has_result_set(): - rows = cur.fetchall() - for row in rows: - js = json.loads(row[1]) - doc = Document(page_content=row[0], metadata=js) - result_vector = HanaDB._parse_float_array_from_string(row[2]) - result.append((doc, row[3], result_vector)) - finally: - cur.close() - return result - - def similarity_search_with_score_by_vector( - self, embedding: List[float], k: int = 4, filter: Optional[dict] = None - ) -> List[Tuple[Document, float]]: - """Return docs most similar to the given embedding. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: A dictionary of metadata fields and values to filter by. - Defaults to None. - - Returns: - List of Documents most similar to the query and score for each - """ - whole_result = self.similarity_search_with_score_and_vector_by_vector( - embedding=embedding, k=k, filter=filter - ) - return [(result_item[0], result_item[1]) for result_item in whole_result] - - def similarity_search_by_vector( # type: ignore[override] - self, embedding: List[float], k: int = 4, filter: Optional[dict] = None - ) -> List[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: A dictionary of metadata fields and values to filter by. - Defaults to None. - - Returns: - List of Documents most similar to the query vector. 
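# [Editorial sketch, not part of the diff] Searching with a pre-computed vector
# via the *_by_vector methods defined above, reusing `db` and `emb` from the
# earlier sketch.
vector = emb.embed_query("Which city is in Germany?")
scored = db.similarity_search_with_score_by_vector(vector, k=2)
for doc, score in scored:
    print(doc.page_content, score)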
- """ - docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) - return [doc for doc, _ in docs_and_scores] - - def _create_where_by_filter(self, filter: Optional[dict]) -> Tuple[str, list[Any]]: - query_tuple: list[Any] = [] - where_str = "" - if filter: - where_str, query_tuple = self._process_filter_object(filter) - where_str = " WHERE " + where_str - return where_str, query_tuple - - def _process_filter_object(self, filter: Optional[dict]) -> Tuple[str, list[Any]]: - query_tuple = [] - where_str = "" - if filter: - for i, key in enumerate(filter.keys()): - filter_value = filter[key] - if i != 0: - where_str += " AND " - - # Handling of 'special' boolean operators "$and", "$or" - if key in LOGICAL_OPERATORS_TO_SQL: - logical_operator = LOGICAL_OPERATORS_TO_SQL[key] - logical_operands = filter_value - for j, logical_operand in enumerate(logical_operands): - if j != 0: - where_str += f" {logical_operator} " - ( - where_str_logical, - query_tuple_logical, - ) = self._process_filter_object(logical_operand) - where_str += "(" + where_str_logical + ")" - query_tuple += query_tuple_logical - continue - - operator = "=" - sql_param = "?" - - if isinstance(filter_value, bool): - query_tuple.append("true" if filter_value else "false") - elif isinstance(filter_value, int) or isinstance(filter_value, str): - query_tuple.append(filter_value) - elif isinstance(filter_value, Dict): - # Handling of 'special' operators starting with "$" - special_op = next(iter(filter_value)) - special_val = filter_value[special_op] - # "$eq", "$ne", "$lt", "$lte", "$gt", "$gte" - if special_op in COMPARISONS_TO_SQL: - operator = COMPARISONS_TO_SQL[special_op] - if isinstance(special_val, bool): - query_tuple.append("true" if special_val else "false") - elif isinstance(special_val, float): - sql_param = "CAST(? as float)" - query_tuple.append(special_val) - elif ( - isinstance(special_val, dict) - and "type" in special_val - and special_val["type"] == "date" - ): - # Date type - sql_param = "CAST(? as DATE)" - query_tuple.append(special_val["date"]) - else: - query_tuple.append(special_val) - # "$between" - elif special_op == BETWEEN_OPERATOR: - between_from = special_val[0] - between_to = special_val[1] - operator = "BETWEEN" - sql_param = "? AND ?" - query_tuple.append(between_from) - query_tuple.append(between_to) - # "$like" - elif special_op == LIKE_OPERATOR: - operator = "LIKE" - query_tuple.append(special_val) - # "$in", "$nin" - elif special_op in IN_OPERATORS_TO_SQL: - operator = IN_OPERATORS_TO_SQL[special_op] - if isinstance(special_val, list): - for i, list_entry in enumerate(special_val): - if i == 0: - sql_param = "(" - sql_param = sql_param + "?" 
- if i == (len(special_val) - 1): - sql_param = sql_param + ")" - else: - sql_param = sql_param + "," - query_tuple.append(list_entry) - else: - raise ValueError( - f"Unsupported value for {operator}: {special_val}" - ) - else: - raise ValueError(f"Unsupported operator: {special_op}") - else: - raise ValueError( - f"Unsupported filter data-type: {type(filter_value)}" - ) - - selector = ( - f' "{key}"' - if key in self.specific_metadata_columns - else f"JSON_VALUE({self.metadata_column}, '$.{key}')" - ) - where_str += f"{selector} {operator} {sql_param}" - - return where_str, query_tuple - - def delete( # type: ignore[override] - self, ids: Optional[List[str]] = None, filter: Optional[dict] = None - ) -> Optional[bool]: - """Delete entries by filter with metadata values - - Args: - ids: Deletion with ids is not supported! A ValueError will be raised. - filter: A dictionary of metadata fields and values to filter by. - An empty filter ({}) will delete all entries in the table. - - Returns: - Optional[bool]: True, if deletion is technically successful. - Deletion of zero entries, due to non-matching filters is a success. - """ - - if ids is not None: - raise ValueError("Deletion via ids is not supported") - - if filter is None: - raise ValueError("Parameter 'filter' is required when calling 'delete'") - - where_str, query_tuple = self._create_where_by_filter(filter) - sql_str = f'DELETE FROM "{self.table_name}" {where_str}' - - try: - cur = self.connection.cursor() - cur.execute(sql_str, query_tuple) - finally: - cur.close() - - return True - - async def adelete( # type: ignore[override] - self, ids: Optional[List[str]] = None, filter: Optional[dict] = None - ) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. - """ - return await run_in_executor(None, self.delete, ids=ids, filter=filter) - - def max_marginal_relevance_search( # type: ignore[override] - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: search query text. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filter on metadata properties, e.g. - { - "str_property": "foo", - "int_property": 123 - } - Returns: - List of Documents selected by maximal marginal relevance. 
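# [Editorial sketch, not part of the diff] The filter dictionaries handled by
# _process_filter_object() above use Mongo-style operators ($and/$or, $eq, $gte,
# $like, $in, ...). A hedged example, reusing `db` from the earlier sketch; the
# metadata keys are hypothetical.
hits = db.similarity_search_with_score(
    "capital cities",
    k=5,
    filter={
        "$or": [
            {"country": "DE"},                    # plain equality
            {"population": {"$gte": 1_000_000}},  # comparison operator
            {"name": {"$like": "%burg%"}},        # SQL LIKE pattern
            {"country": {"$in": ["FR", "IT"]}},   # membership test
        ]
    },
)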
- """ - embedding = self.embedding.embed_query(query) - return self.max_marginal_relevance_search_by_vector( - embedding=embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - ) - - def _parse_float_array_from_string(array_as_string: str) -> List[float]: # type: ignore[misc] - array_wo_brackets = array_as_string[1:-1] - return [float(x) for x in array_wo_brackets.split(",")] - - def max_marginal_relevance_search_by_vector( # type: ignore[override] - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - ) -> List[Document]: - whole_result = self.similarity_search_with_score_and_vector_by_vector( - embedding=embedding, k=fetch_k, filter=filter - ) - embeddings = [result_item[2] for result_item in whole_result] - mmr_doc_indexes = maximal_marginal_relevance( - np.array(embedding), embeddings, lambda_mult=lambda_mult, k=k - ) - - return [whole_result[i][0] for i in mmr_doc_indexes] - - async def amax_marginal_relevance_search_by_vector( # type: ignore[override] - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance.""" - return await run_in_executor( - None, - self.max_marginal_relevance_search_by_vector, - embedding=embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - ) - - @staticmethod - def _cosine_relevance_score_fn(distance: float) -> float: - return distance - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - - Vectorstores should define their own selection based method of relevance. 
- """ - if self.distance_strategy == DistanceStrategy.COSINE: - return HanaDB._cosine_relevance_score_fn - elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - return HanaDB._euclidean_relevance_score_fn - else: - raise ValueError( - "Unsupported distance_strategy: {}".format(self.distance_strategy) - ) diff --git a/libs/community/langchain_community/vectorstores/matching_engine.py b/libs/community/langchain_community/vectorstores/matching_engine.py deleted file mode 100644 index d0f9c0b9a..000000000 --- a/libs/community/langchain_community/vectorstores/matching_engine.py +++ /dev/null @@ -1,606 +0,0 @@ -from __future__ import annotations - -import json -import logging -import time -import uuid -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type - -from langchain_core._api.deprecation import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore - -from langchain_community.utilities.vertexai import get_client_info - -if TYPE_CHECKING: - from google.cloud import storage - from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint - from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( - Namespace, - ) - from google.oauth2.service_account import Credentials - - from langchain_community.embeddings import TensorflowHubEmbeddings - -logger = logging.getLogger(__name__) - - -@deprecated( - since="0.0.12", - removal="1.0", - alternative_import="langchain_google_vertexai.VectorSearchVectorStore", -) -class MatchingEngine(VectorStore): - """`Google Vertex AI Vector Search` (previously Matching Engine) vector store. - - While the embeddings are stored in the Matching Engine, the embedded - documents will be stored in GCS. - - An existing Index and corresponding Endpoint are preconditions for - using this module. - - See usage in docs/integrations/vectorstores/google_vertex_ai_vector_search.ipynb - - Note that this implementation is mostly meant for reading if you are - planning to do a real time implementation. While reading is a real time - operation, updating the index takes close to one hour.""" - - def __init__( - self, - project_id: str, - index: MatchingEngineIndex, - endpoint: MatchingEngineIndexEndpoint, - embedding: Embeddings, - gcs_client: storage.Client, - gcs_bucket_name: str, - credentials: Optional[Credentials] = None, - *, - document_id_key: Optional[str] = None, - ): - """Google Vertex AI Vector Search (previously Matching Engine) - implementation of the vector store. - - While the embeddings are stored in the Matching Engine, the embedded - documents will be stored in GCS. - - An existing Index and corresponding Endpoint are preconditions for - using this module. - - See usage in - docs/integrations/vectorstores/google_vertex_ai_vector_search.ipynb. - - Note that this implementation is mostly meant for reading if you are - planning to do a real time implementation. While reading is a real time - operation, updating the index takes close to one hour. - - Attributes: - project_id: The GCS project id. - index: The created index class. See - ~:func:`MatchingEngine.from_components`. - endpoint: The created endpoint class. See - ~:func:`MatchingEngine.from_components`. - embedding: A :class:`Embeddings` that will be used for - embedding the text sent. If none is sent, then the - multilingual Tensorflow Universal Sentence Encoder will be used. - gcs_client: The GCS client. 
- gcs_bucket_name: The GCS bucket name. - credentials (Optional): Created GCP credentials. - document_id_key (Optional): Key for storing document ID in document - metadata. If None, document ID will not be returned in document - metadata. - """ - super().__init__() - self._validate_google_libraries_installation() - - self.project_id = project_id - self.index = index - self.endpoint = endpoint - self.embedding = embedding - self.gcs_client = gcs_client - self.credentials = credentials - self.gcs_bucket_name = gcs_bucket_name - self.document_id_key = document_id_key - - @property - def embeddings(self) -> Embeddings: - return self.embedding - - def _validate_google_libraries_installation(self) -> None: - """Validates that Google libraries that are needed are installed.""" - try: - from google.cloud import aiplatform, storage # noqa: F401 - from google.oauth2 import service_account # noqa: F401 - except ImportError: - raise ImportError( - "You must run `pip install --upgrade " - "google-cloud-aiplatform google-cloud-storage`" - "to use the MatchingEngine Vectorstore." - ) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - kwargs: vectorstore specific parameters. - - Returns: - List of ids from adding the texts into the vectorstore. - """ - texts = list(texts) - if metadatas is not None and len(texts) != len(metadatas): - raise ValueError( - "texts and metadatas do not have the same length. Received " - f"{len(texts)} texts and {len(metadatas)} metadatas." - ) - logger.debug("Embedding documents.") - embeddings = self.embedding.embed_documents(texts) - jsons = [] - ids = [] - # Could be improved with async. - for idx, (embedding, text) in enumerate(zip(embeddings, texts)): - id = str(uuid.uuid4()) - ids.append(id) - json_: dict = {"id": id, "embedding": embedding} - if metadatas is not None: - json_["metadata"] = metadatas[idx] - jsons.append(json_) - self._upload_to_gcs(text, f"documents/{id}") - - logger.debug(f"Uploaded {len(ids)} documents to GCS.") - - # Creating json lines from the embedded documents. - result_str = "\n".join([json.dumps(x) for x in jsons]) - - filename_prefix = f"indexes/{uuid.uuid4()}" - filename = f"{filename_prefix}/{time.time()}.json" - self._upload_to_gcs(result_str, filename) - logger.debug( - f"Uploaded updated json with embeddings to " - f"{self.gcs_bucket_name}/{filename}." - ) - - self.index = self.index.update_embeddings( - contents_delta_uri=f"gs://{self.gcs_bucket_name}/{filename_prefix}/" - ) - - logger.debug("Updated index with new configuration.") - - return ids - - def _upload_to_gcs(self, data: str, gcs_location: str) -> None: - """Uploads data to gcs_location. - - Args: - data: The data that will be stored. - gcs_location: The location where the data will be stored. - """ - bucket = self.gcs_client.get_bucket(self.gcs_bucket_name) - blob = bucket.blob(gcs_location) - blob.upload_from_string(data) - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[List[Namespace]] = None, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to query and their cosine distance from the query. - - Args: - query: String query look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Optional. 
A list of Namespaces for filtering - the matching results. - For example: - [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] - will match datapoints that satisfy "red color" but not include - datapoints with "squared shape". Please refer to - https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json - for more detail. - - Returns: - List[Tuple[Document, float]]: List of documents most similar to - the query text and cosine distance in float for each. - Lower score represents more similarity. - """ - logger.debug(f"Embedding query {query}.") - embedding_query = self.embedding.embed_query(query) - return self.similarity_search_by_vector_with_score( - embedding_query, k=k, filter=filter - ) - - def similarity_search_by_vector_with_score( - self, - embedding: List[float], - k: int = 4, - filter: Optional[List[Namespace]] = None, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to the embedding and their cosine distance. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Optional. A list of Namespaces for filtering - the matching results. - For example: - [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] - will match datapoints that satisfy "red color" but not include - datapoints with "squared shape". Please refer to - https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json - for more detail. - - Returns: - List[Tuple[Document, float]]: List of documents most similar to - the query text and cosine distance in float for each. - Lower score represents more similarity. - - """ - filter = filter or [] - - # If the endpoint is public we use the find_neighbors function. - if hasattr(self.endpoint, "_public_match_client") and ( - self.endpoint._public_match_client - ): - response = self.endpoint.find_neighbors( - deployed_index_id=self._get_index_id(), - queries=[embedding], - num_neighbors=k, - filter=filter, - ) - else: - response = self.endpoint.match( - deployed_index_id=self._get_index_id(), - queries=[embedding], - num_neighbors=k, - filter=filter, - ) - - logger.debug(f"Found {len(response)} matches.") - - if len(response) == 0: - return [] - - docs: List[Tuple[Document, float]] = [] - - # I'm only getting the first one because queries receives an array - # and the similarity_search method only receives one query. This - # means that the match method will always return an array with only - # one element. - for result in response[0]: - page_content = self._download_from_gcs(f"documents/{result.id}") - # TODO: return all metadata. - metadata = {} - if self.document_id_key is not None: - metadata[self.document_id_key] = result.id - document = Document( - page_content=page_content, - metadata=metadata, - ) - docs.append((document, result.distance)) - - logger.debug("Downloaded documents for query.") - - return docs - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[List[Namespace]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to query. - - Args: - query: The string that will be used to search for similar documents. - k: The amount of neighbors that will be retrieved. - filter: Optional. A list of Namespaces for filtering the matching results. - For example: - [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] - will match datapoints that satisfy "red color" but not include - datapoints with "squared shape". 
Please refer to - https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json - for more detail. - - Returns: - A list of k matching documents. - """ - docs_and_scores = self.similarity_search_with_score( - query, k=k, filter=filter, **kwargs - ) - - return [doc for doc, _ in docs_and_scores] - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[List[Namespace]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to the embedding. - - Args: - embedding: Embedding to look up documents similar to. - k: The amount of neighbors that will be retrieved. - filter: Optional. A list of Namespaces for filtering the matching results. - For example: - [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] - will match datapoints that satisfy "red color" but not include - datapoints with "squared shape". Please refer to - https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json - for more detail. - - Returns: - A list of k matching documents. - """ - docs_and_scores = self.similarity_search_by_vector_with_score( - embedding, k=k, filter=filter, **kwargs - ) - - return [doc for doc, _ in docs_and_scores] - - def _get_index_id(self) -> str: - """Gets the correct index id for the endpoint. - - Returns: - The index id if found (which should be found) or throws - ValueError otherwise. - """ - for index in self.endpoint.deployed_indexes: - if index.index == self.index.resource_name: - return index.id - - raise ValueError( - f"No index with id {self.index.resource_name} " - f"deployed on endpoint " - f"{self.endpoint.display_name}." - ) - - def _download_from_gcs(self, gcs_location: str) -> str: - """Downloads from GCS in text format. - - Args: - gcs_location: The location where the file is located. - - Returns: - The string contents of the file. - """ - bucket = self.gcs_client.get_bucket(self.gcs_bucket_name) - blob = bucket.blob(gcs_location) - return blob.download_as_string() - - @classmethod - def from_texts( - cls: Type["MatchingEngine"], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> "MatchingEngine": - """Use from components instead.""" - raise NotImplementedError( - "This method is not implemented. Instead, you should initialize the class" - " with `MatchingEngine.from_components(...)` and then call " - "`add_texts`" - ) - - @classmethod - def from_components( - cls: Type["MatchingEngine"], - project_id: str, - region: str, - gcs_bucket_name: str, - index_id: str, - endpoint_id: str, - credentials_path: Optional[str] = None, - embedding: Optional[Embeddings] = None, - **kwargs: Any, - ) -> "MatchingEngine": - """Takes the object creation out of the constructor. - - Args: - project_id: The GCP project id. - region: The default location making the API calls. It must have - the same location as the GCS bucket and must be regional. - gcs_bucket_name: The location where the vectors will be stored in - order for the index to be created. - index_id: The id of the created index. - endpoint_id: The id of the created endpoint. - credentials_path: (Optional) The path of the Google credentials on - the local file system. - embedding: The :class:`Embeddings` that will be used for - embedding the texts. - kwargs: Additional keyword arguments to pass to MatchingEngine.__init__(). - - Returns: - A configured MatchingEngine with the texts added to the index. 
- """ - gcs_bucket_name = cls._validate_gcs_bucket(gcs_bucket_name) - credentials = cls._create_credentials_from_file(credentials_path) - index = cls._create_index_by_id(index_id, project_id, region, credentials) - endpoint = cls._create_endpoint_by_id( - endpoint_id, - project_id, - region, - credentials, - ) - - gcs_client = cls._get_gcs_client(credentials, project_id) - cls._init_aiplatform(project_id, region, gcs_bucket_name, credentials) - - return cls( - project_id=project_id, - index=index, - endpoint=endpoint, - embedding=embedding or cls._get_default_embeddings(), - gcs_client=gcs_client, - credentials=credentials, - gcs_bucket_name=gcs_bucket_name, - **kwargs, - ) - - @classmethod - def _validate_gcs_bucket(cls, gcs_bucket_name: str) -> str: - """Validates the gcs_bucket_name as a bucket name. - - Args: - gcs_bucket_name: The received bucket uri. - - Returns: - A valid gcs_bucket_name or throws ValueError if full path is - provided. - """ - gcs_bucket_name = gcs_bucket_name.replace("gs://", "") - if "/" in gcs_bucket_name: - raise ValueError( - f"The argument gcs_bucket_name should only be " - f"the bucket name. Received {gcs_bucket_name}" - ) - return gcs_bucket_name - - @classmethod - def _create_credentials_from_file( - cls, json_credentials_path: Optional[str] - ) -> Optional[Credentials]: - """Creates credentials for GCP. - - Args: - json_credentials_path: The path on the file system where the - credentials are stored. - - Returns: - An optional of Credentials or None, in which case the default - will be used. - """ - - from google.oauth2 import service_account - - credentials = None - if json_credentials_path is not None: - credentials = service_account.Credentials.from_service_account_file( - json_credentials_path - ) - - return credentials - - @classmethod - def _create_index_by_id( - cls, index_id: str, project_id: str, region: str, credentials: "Credentials" - ) -> MatchingEngineIndex: - """Creates a MatchingEngineIndex object by id. - - Args: - index_id: The created index id. - project_id: The project to retrieve index from. - region: Location to retrieve index from. - credentials: GCS credentials. - - Returns: - A configured MatchingEngineIndex. - """ - - from google.cloud import aiplatform - - logger.debug(f"Creating matching engine index with id {index_id}.") - return aiplatform.MatchingEngineIndex( - index_name=index_id, - project=project_id, - location=region, - credentials=credentials, - ) - - @classmethod - def _create_endpoint_by_id( - cls, endpoint_id: str, project_id: str, region: str, credentials: "Credentials" - ) -> MatchingEngineIndexEndpoint: - """Creates a MatchingEngineIndexEndpoint object by id. - - Args: - endpoint_id: The created endpoint id. - project_id: The project to retrieve index from. - region: Location to retrieve index from. - credentials: GCS credentials. - - Returns: - A configured MatchingEngineIndexEndpoint. - """ - - from google.cloud import aiplatform - - logger.debug(f"Creating endpoint with id {endpoint_id}.") - return aiplatform.MatchingEngineIndexEndpoint( - index_endpoint_name=endpoint_id, - project=project_id, - location=region, - credentials=credentials, - ) - - @classmethod - def _get_gcs_client( - cls, credentials: "Credentials", project_id: str - ) -> "storage.Client": - """Lazily creates a GCS client. - - Returns: - A configured GCS client. 
- """ - - from google.cloud import storage - - return storage.Client( - credentials=credentials, - project=project_id, - client_info=get_client_info(module="vertex-ai-matching-engine"), - ) - - @classmethod - def _init_aiplatform( - cls, - project_id: str, - region: str, - gcs_bucket_name: str, - credentials: "Credentials", - ) -> None: - """Configures the aiplatform library. - - Args: - project_id: The GCP project id. - region: The default location making the API calls. It must have - the same location as the GCS bucket and must be regional. - gcs_bucket_name: GCS staging location. - credentials: The GCS Credentials object. - """ - - from google.cloud import aiplatform - - logger.debug( - f"Initializing AI Platform for project {project_id} on " - f"{region} and for {gcs_bucket_name}." - ) - aiplatform.init( - project=project_id, - location=region, - staging_bucket=gcs_bucket_name, - credentials=credentials, - ) - - @classmethod - def _get_default_embeddings(cls) -> "TensorflowHubEmbeddings": - """This function returns the default embedding. - - Returns: - Default TensorflowHubEmbeddings to use. - """ - - from langchain_community.embeddings import TensorflowHubEmbeddings - - return TensorflowHubEmbeddings() diff --git a/libs/community/langchain_community/vectorstores/milvus.py b/libs/community/langchain_community/vectorstores/milvus.py index 2a804bea2..5ebd1b72a 100644 --- a/libs/community/langchain_community/vectorstores/milvus.py +++ b/libs/community/langchain_community/vectorstores/milvus.py @@ -32,7 +32,7 @@ alternative_import="langchain_milvus.MilvusVectorStore", ) class Milvus(VectorStore): - """`Milvus` vector store. + """`Milvus` vector store. DO NOT USE. KEPT FOR BACKWARDS COMPATIBILITY. You need to install `pymilvus` and run Milvus. diff --git a/libs/community/langchain_community/vectorstores/mongodb_atlas.py b/libs/community/langchain_community/vectorstores/mongodb_atlas.py deleted file mode 100644 index 10fab4ec1..000000000 --- a/libs/community/langchain_community/vectorstores/mongodb_atlas.py +++ /dev/null @@ -1,376 +0,0 @@ -from __future__ import annotations - -import logging -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Tuple, - TypeVar, - Union, -) - -import numpy as np -from langchain_core._api.deprecation import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore - -from langchain_community.vectorstores.utils import maximal_marginal_relevance - -if TYPE_CHECKING: - from pymongo.collection import Collection - -MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any]) - -logger = logging.getLogger(__name__) - -DEFAULT_INSERT_BATCH_SIZE = 100 - - -@deprecated( - since="0.0.25", - removal="1.0", - alternative_import="langchain_mongodb.MongoDBAtlasVectorSearch", -) -class MongoDBAtlasVectorSearch(VectorStore): - """`MongoDB Atlas Vector Search` vector store. - - To use, you should have both: - - the ``pymongo`` python package installed - - a connection string associated with a MongoDB Atlas Cluster having deployed an - Atlas Search index - - Example: - .. 
code-block:: python - - from langchain_community.vectorstores import MongoDBAtlasVectorSearch - from langchain_community.embeddings.openai import OpenAIEmbeddings - from pymongo import MongoClient - - mongo_client = MongoClient("") - collection = mongo_client[""][""] - embeddings = OpenAIEmbeddings() - vectorstore = MongoDBAtlasVectorSearch(collection, embeddings) - """ - - def __init__( - self, - collection: Collection[MongoDBDocumentType], - embedding: Embeddings, - *, - index_name: str = "default", - text_key: str = "text", - embedding_key: str = "embedding", - relevance_score_fn: str = "cosine", - ): - """ - Args: - collection: MongoDB collection to add the texts to. - embedding: Text embedding model to use. - text_key: MongoDB field that will contain the text for each - document. - embedding_key: MongoDB field that will contain the embedding for - each document. - index_name: Name of the Atlas Search index. - relevance_score_fn: The similarity score used for the index. - Currently supported: Euclidean, cosine, and dot product. - """ - self._collection = collection - self._embedding = embedding - self._index_name = index_name - self._text_key = text_key - self._embedding_key = embedding_key - self._relevance_score_fn = relevance_score_fn - - @property - def embeddings(self) -> Embeddings: - return self._embedding - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - if self._relevance_score_fn == "euclidean": - return self._euclidean_relevance_score_fn - elif self._relevance_score_fn == "dotProduct": - return self._max_inner_product_relevance_score_fn - elif self._relevance_score_fn == "cosine": - return self._cosine_relevance_score_fn - else: - raise NotImplementedError( - f"No relevance score function for ${self._relevance_score_fn}" - ) - - @classmethod - def from_connection_string( - cls, - connection_string: str, - namespace: str, - embedding: Embeddings, - **kwargs: Any, - ) -> MongoDBAtlasVectorSearch: - """Construct a `MongoDB Atlas Vector Search` vector store - from a MongoDB connection URI. - - Args: - connection_string: A valid MongoDB connection URI. - namespace: A valid MongoDB namespace (database and collection). - embedding: The text embedding model to use for the vector store. - - Returns: - A new MongoDBAtlasVectorSearch instance. - - """ - try: - from importlib.metadata import version - - from pymongo import MongoClient - from pymongo.driver_info import DriverInfo - except ImportError: - raise ImportError( - "Could not import pymongo, please install it with " - "`pip install pymongo`." - ) - client: MongoClient = MongoClient( - connection_string, - driver=DriverInfo(name="Langchain", version=version("langchain")), - ) - db_name, collection_name = namespace.split(".") - collection = client[db_name][collection_name] - return cls(collection, embedding, **kwargs) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[Dict[str, Any]]] = None, - **kwargs: Any, - ) -> List: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - - Returns: - List of ids from adding the texts into the vectorstore. 
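# [Editorial sketch, not part of the diff] Constructing the deprecated
# MongoDBAtlasVectorSearch through the from_connection_string() factory defined
# above; the URI, namespace and index name are placeholders, and OpenAIEmbeddings
# merely stands in for any Embeddings implementation, mirroring the class docstring.
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch  # pre-removal import path

vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
    "mongodb+srv://<user>:<password>@<cluster>.mongodb.net",
    "<db_name>.<collection_name>",  # namespace: "database.collection"
    OpenAIEmbeddings(),
    index_name="default",           # name of the Atlas Vector Search index
)
ids = vectorstore.add_texts(
    ["LangChain supports MongoDB Atlas", "Vector search uses embeddings"],
    metadatas=[{"topic": "integrations"}, {"topic": "search"}],
)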
- """ - batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE) - _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts) - texts_batch = [] - metadatas_batch = [] - result_ids = [] - for i, (text, metadata) in enumerate(zip(texts, _metadatas)): - texts_batch.append(text) - metadatas_batch.append(metadata) - if (i + 1) % batch_size == 0: - result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) - texts_batch = [] - metadatas_batch = [] - if texts_batch: - result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) - return result_ids - - def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List: - if not texts: - return [] - # Embed and create the documents - embeddings = self._embedding.embed_documents(texts) - to_insert = [ - {self._text_key: t, self._embedding_key: embedding, **m} - for t, m, embedding in zip(texts, metadatas, embeddings) - ] - # insert the documents in MongoDB Atlas - insert_result = self._collection.insert_many(to_insert) - return insert_result.inserted_ids - - def _similarity_search_with_score( - self, - embedding: List[float], - k: int = 4, - pre_filter: Optional[Dict] = None, - post_filter_pipeline: Optional[List[Dict]] = None, - ) -> List[Tuple[Document, float]]: - params = { - "queryVector": embedding, - "path": self._embedding_key, - "numCandidates": k * 10, - "limit": k, - "index": self._index_name, - } - if pre_filter: - params["filter"] = pre_filter - query = {"$vectorSearch": params} - - pipeline = [ - query, - {"$set": {"score": {"$meta": "vectorSearchScore"}}}, - ] - if post_filter_pipeline is not None: - pipeline.extend(post_filter_pipeline) - cursor = self._collection.aggregate(pipeline) - docs = [] - for res in cursor: - text = res.pop(self._text_key) - score = res.pop("score") - docs.append((Document(page_content=text, metadata=res), score)) - return docs - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - pre_filter: Optional[Dict] = None, - post_filter_pipeline: Optional[List[Dict]] = None, - ) -> List[Tuple[Document, float]]: - """Return MongoDB documents most similar to the given query and their scores. - - Uses the vectorSearch operator available in MongoDB Atlas Search. - For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ - - Args: - query: Text to look up documents similar to. - k: (Optional) number of documents to return. Defaults to 4. - pre_filter: (Optional) dictionary of argument(s) to prefilter document - fields on. - post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages - following the vectorSearch stage. - - Returns: - List of documents most similar to the query and their scores. - """ - embedding = self._embedding.embed_query(query) - docs = self._similarity_search_with_score( - embedding, - k=k, - pre_filter=pre_filter, - post_filter_pipeline=post_filter_pipeline, - ) - return docs - - def similarity_search( - self, - query: str, - k: int = 4, - pre_filter: Optional[Dict] = None, - post_filter_pipeline: Optional[List[Dict]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return MongoDB documents most similar to the given query. - - Uses the vectorSearch operator available in MongoDB Atlas Search. - For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ - - Args: - query: Text to look up documents similar to. - k: (Optional) number of documents to return. Defaults to 4. - pre_filter: (Optional) dictionary of argument(s) to prefilter document - fields on. 
- post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages - following the vectorSearch stage. - - Returns: - List of documents most similar to the query and their scores. - """ - additional = kwargs.get("additional") - docs_and_scores = self.similarity_search_with_score( - query, - k=k, - pre_filter=pre_filter, - post_filter_pipeline=post_filter_pipeline, - ) - - if additional and "similarity_score" in additional: - for doc, score in docs_and_scores: - doc.metadata["score"] = score - return [doc for doc, _ in docs_and_scores] - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - pre_filter: Optional[Dict] = None, - post_filter_pipeline: Optional[List[Dict]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return documents selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: (Optional) number of documents to return. Defaults to 4. - fetch_k: (Optional) number of documents to fetch before passing to MMR - algorithm. Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - pre_filter: (Optional) dictionary of argument(s) to prefilter on document - fields. - post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages - following the vectorSearch stage. - Returns: - List of documents selected by maximal marginal relevance. - """ - query_embedding = self._embedding.embed_query(query) - docs = self._similarity_search_with_score( - query_embedding, - k=fetch_k, - pre_filter=pre_filter, - post_filter_pipeline=post_filter_pipeline, - ) - mmr_doc_indexes = maximal_marginal_relevance( - np.array(query_embedding), - [doc.metadata[self._embedding_key] for doc, _ in docs], - k=k, - lambda_mult=lambda_mult, - ) - mmr_docs = [docs[i][0] for i in mmr_doc_indexes] - return mmr_docs - - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[Dict]] = None, - collection: Optional[Collection[MongoDBDocumentType]] = None, - **kwargs: Any, - ) -> MongoDBAtlasVectorSearch: - """Construct a `MongoDB Atlas Vector Search` vector store from raw documents. - - This is a user-friendly interface that: - 1. Embeds documents. - 2. Adds the documents to a provided MongoDB Atlas Vector Search index - (Lucene) - - This is intended to be a quick way to get started. - - Example: - .. 
code-block:: python - from pymongo import MongoClient - - from langchain_community.vectorstores import MongoDBAtlasVectorSearch - from langchain_community.embeddings import OpenAIEmbeddings - - mongo_client = MongoClient("") - collection = mongo_client[""][""] - embeddings = OpenAIEmbeddings() - vectorstore = MongoDBAtlasVectorSearch.from_texts( - texts, - embeddings, - metadatas=metadatas, - collection=collection - ) - """ - if collection is None: - raise ValueError("Must provide 'collection' named parameter.") - vectorstore = cls(collection, embedding, **kwargs) - vectorstore.add_texts(texts, metadatas=metadatas) - return vectorstore diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py deleted file mode 100644 index 57c7202b3..000000000 --- a/libs/community/langchain_community/vectorstores/neo4j_vector.py +++ /dev/null @@ -1,1688 +0,0 @@ -from __future__ import annotations - -import enum -import logging -import os -from hashlib import md5 -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, -) - -import numpy as np -from langchain_core._api.deprecation import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.utils import get_from_dict_or_env -from langchain_core.vectorstores import VectorStore - -from langchain_community.graphs import Neo4jGraph -from langchain_community.vectorstores.utils import ( - DistanceStrategy, - maximal_marginal_relevance, -) - -DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE -DISTANCE_MAPPING = { - DistanceStrategy.EUCLIDEAN_DISTANCE: "euclidean", - DistanceStrategy.COSINE: "cosine", -} - -COMPARISONS_TO_NATIVE = { - "$eq": "=", - "$ne": "<>", - "$lt": "<", - "$lte": "<=", - "$gt": ">", - "$gte": ">=", -} - -SPECIAL_CASED_OPERATORS = { - "$in", - "$nin", - "$between", -} - -TEXT_OPERATORS = { - "$like", - "$ilike", -} - -LOGICAL_OPERATORS = {"$and", "$or"} - -SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(TEXT_OPERATORS) - .union(LOGICAL_OPERATORS) - .union(SPECIAL_CASED_OPERATORS) -) - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.SearchType", -) -class SearchType(str, enum.Enum): - """Enumerator of the Distance strategies.""" - - VECTOR = "vector" - HYBRID = "hybrid" - - -DEFAULT_SEARCH_TYPE = SearchType.VECTOR - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.IndexType", -) -class IndexType(str, enum.Enum): - """Enumerator of the index types.""" - - NODE = "NODE" - RELATIONSHIP = "RELATIONSHIP" - - -DEFAULT_INDEX_TYPE = IndexType.NODE - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector._get_search_index_query", -) -def _get_search_index_query( - search_type: SearchType, index_type: IndexType = DEFAULT_INDEX_TYPE -) -> str: - if index_type == IndexType.NODE: - type_to_query_map = { - SearchType.VECTOR: ( - "CALL db.index.vector.queryNodes($index, $k, $embedding) " - "YIELD node, score " - ), - SearchType.HYBRID: ( - "CALL { " - "CALL db.index.vector.queryNodes($index, $k, $embedding) " - "YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - # We use 0 as min - "RETURN n.node AS node, (n.score / max) AS score UNION " - "CALL db.index.fulltext.queryNodes($keyword_index, 
$query, " - "{limit: $k}) YIELD node, score " - "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " - "UNWIND nodes AS n " - # We use 0 as min - "RETURN n.node AS node, (n.score / max) AS score " - "} " - # dedup - "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " - ), - } - return type_to_query_map[search_type] - else: - return ( - "CALL db.index.vector.queryRelationships($index, $k, $embedding) " - "YIELD relationship, score " - ) - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.check_if_not_null", -) -def check_if_not_null(props: List[str], values: List[Any]) -> None: - """Check if the values are not None or empty string""" - for prop, value in zip(props, values): - if not value: - raise ValueError(f"Parameter `{prop}` must not be None or empty string") - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.sort_by_index_name", -) -def sort_by_index_name( - lst: List[Dict[str, Any]], index_name: str -) -> List[Dict[str, Any]]: - """Sort first element to match the index_name if exists""" - return sorted(lst, key=lambda x: x.get("name") != index_name) - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.remove_lucene_chars", -) -def remove_lucene_chars(text: str) -> str: - """Remove Lucene special characters""" - special_chars = [ - "+", - "-", - "&", - "|", - "!", - "(", - ")", - "{", - "}", - "[", - "]", - "^", - '"', - "~", - "*", - "?", - ":", - "\\", - ] - for char in special_chars: - if char in text: - text = text.replace(char, " ") - return text.strip() - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.dict_to_yaml_str", -) -def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str: - """ - Convert a dictionary to a YAML-like string without using external libraries. - - Parameters: - - input_dict (dict): The dictionary to convert. - - indent (int): The current indentation level. - - Returns: - - str: The YAML-like string representation of the input dictionary. 
- """ - yaml_str = "" - for key, value in input_dict.items(): - padding = " " * indent - if isinstance(value, dict): - yaml_str += f"{padding}{key}:\n{dict_to_yaml_str(value, indent + 1)}" - elif isinstance(value, list): - yaml_str += f"{padding}{key}:\n" - for item in value: - yaml_str += f"{padding}- {item}\n" - else: - yaml_str += f"{padding}{key}: {value}\n" - return yaml_str - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.combine_queries", -) -def combine_queries( - input_queries: List[Tuple[str, Dict[str, Any]]], operator: str -) -> Tuple[str, Dict[str, Any]]: - """Combine multiple queries with an operator.""" - - # Initialize variables to hold the combined query and parameters - combined_query: str = "" - combined_params: Dict = {} - param_counter: Dict = {} - - for query, params in input_queries: - # Process each query fragment and its parameters - new_query = query - for param, value in params.items(): - # Update the parameter name to ensure uniqueness - if param in param_counter: - param_counter[param] += 1 - else: - param_counter[param] = 1 - new_param_name = f"{param}_{param_counter[param]}" - - # Replace the parameter in the query fragment - new_query = new_query.replace(f"${param}", f"${new_param_name}") - # Add the parameter to the combined parameters dictionary - combined_params[new_param_name] = value - - # Combine the query fragments with an AND operator - if combined_query: - combined_query += f" {operator} " - combined_query += f"({new_query})" - - return combined_query, combined_params - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.collect_params", -) -def collect_params( - input_data: List[Tuple[str, Dict[str, str]]], -) -> Tuple[List[str], Dict[str, Any]]: - """Transform the input data into the desired format. - - Args: - - input_data (list of tuples): Input data to transform. - Each tuple contains a string and a dictionary. - - Returns: - - tuple: A tuple containing a list of strings and a dictionary. - """ - # Initialize variables to hold the output parts - query_parts = [] - params = {} - - # Loop through each item in the input data - for query_part, param in input_data: - # Append the query part to the list - query_parts.append(query_part) - # Update the params dictionary with the param dictionary - params.update(param) - - # Return the transformed data - return (query_parts, params) - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector._handle_field_filter", -) -def _handle_field_filter( - field: str, value: Any, param_number: int = 1 -) -> Tuple[str, Dict]: - """Create a filter for a specific field. - - Args: - field: name of field - value: value to filter - If provided as is then this will be an equality filter - If provided as a dictionary then this will be a filter, the key - will be the operator and the value will be the value to filter by - param_number: sequence number of parameters used to map between param - dict and Cypher snippet - - Returns a tuple of - - Cypher filter snippet - - Dictionary with parameters used in filter snippet - """ - if not isinstance(field, str): - raise ValueError( - f"field should be a string but got: {type(field)} with value: {field}" - ) - - if field.startswith("$"): - raise ValueError( - f"Invalid filter condition. 
Expected a field but got an operator: {field}" - ) - - # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters - if not field.isidentifier(): - raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") - - if isinstance(value, dict): - # This is a filter specification - if len(value) != 1: - raise ValueError( - "Invalid filter condition. Expected a value which " - "is a dictionary with a single key that corresponds to an operator " - f"but got a dictionary with {len(value)} keys. The first few " - f"keys are: {list(value.keys())[:3]}" - ) - operator, filter_value = list(value.items())[0] - # Verify that that operator is an operator - if operator not in SUPPORTED_OPERATORS: - raise ValueError( - f"Invalid operator: {operator}. Expected one of {SUPPORTED_OPERATORS}" - ) - else: # Then we assume an equality operator - operator = "$eq" - filter_value = value - - if operator in COMPARISONS_TO_NATIVE: - # Then we implement an equality filter - # native is trusted input - native = COMPARISONS_TO_NATIVE[operator] - query_snippet = f"n.`{field}` {native} $param_{param_number}" - query_param = {f"param_{param_number}": filter_value} - return (query_snippet, query_param) - elif operator == "$between": - low, high = filter_value - query_snippet = ( - f"$param_{param_number}_low <= n.`{field}` <= $param_{param_number}_high" - ) - query_param = { - f"param_{param_number}_low": low, - f"param_{param_number}_high": high, - } - return (query_snippet, query_param) - - elif operator in {"$in", "$nin", "$like", "$ilike"}: - # We'll do force coercion to text - if operator in {"$in", "$nin"}: - for val in filter_value: - if not isinstance(val, (str, int, float)): - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) - if operator in {"$in"}: - query_snippet = f"n.`{field}` IN $param_{param_number}" - query_param = {f"param_{param_number}": filter_value} - return (query_snippet, query_param) - elif operator in {"$nin"}: - query_snippet = f"n.`{field}` NOT IN $param_{param_number}" - query_param = {f"param_{param_number}": filter_value} - return (query_snippet, query_param) - elif operator in {"$like"}: - query_snippet = f"n.`{field}` CONTAINS $param_{param_number}" - query_param = {f"param_{param_number}": filter_value.rstrip("%")} - return (query_snippet, query_param) - elif operator in {"$ilike"}: - query_snippet = f"toLower(n.`{field}`) CONTAINS $param_{param_number}" - query_param = {f"param_{param_number}": filter_value.rstrip("%")} - return (query_snippet, query_param) - else: - raise NotImplementedError() - else: - raise NotImplementedError() - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.vectorstores.neo4j_vector.construct_metadata_filter", -) -def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: - """Construct a metadata filter. - - Args: - filter: A dictionary representing the filter condition. - - Returns: - Tuple[str, Dict] - """ - - if isinstance(filter, dict): - if len(filter) == 1: - # The only operators allowed at the top level are $AND and $OR - # First check if an operator or a field - key, value = list(filter.items())[0] - if key.startswith("$"): - # Then it's an operator - if key.lower() not in ["$and", "$or"]: - raise ValueError( - f"Invalid filter condition. 
Expected $and or $or but got: {key}" - ) - else: - # Then it's a field - return _handle_field_filter(key, filter[key]) - - # Here we handle the $and and $or operators - if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) - if key.lower() == "$and": - and_ = combine_queries( - [construct_metadata_filter(el) for el in value], "AND" - ) - if len(and_) >= 1: - return and_ - else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) - elif key.lower() == "$or": - or_ = combine_queries( - [construct_metadata_filter(el) for el in value], "OR" - ) - if len(or_) >= 1: - return or_ - else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) - else: - raise ValueError( - f"Invalid filter condition. Expected $and or $or but got: {key}" - ) - elif len(filter) > 1: - # Then all keys have to be fields (they cannot be operators) - for key in filter.keys(): - if key.startswith("$"): - raise ValueError( - f"Invalid filter condition. Expected a field but got: {key}" - ) - # These should all be fields and combined using an $and operator - and_multiple = collect_params( - [ - _handle_field_filter(k, v, index) - for index, (k, v) in enumerate(filter.items()) - ] - ) - if len(and_multiple) >= 1: - return " AND ".join(and_multiple[0]), and_multiple[1] - else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) - else: - raise ValueError("Got an empty dictionary for filters.") - - -@deprecated( - since="0.3.8", - removal="1.0", - alternative_import="langchain_neo4j.Neo4jVector", -) -class Neo4jVector(VectorStore): - """`Neo4j` vector index. - - To use, you should have the ``neo4j`` python package installed. - - Args: - url: Neo4j connection url - username: Neo4j username. - password: Neo4j password - database: Optionally provide Neo4j database - Defaults to "neo4j" - embedding: Any embedding function implementing - `langchain.embeddings.base.Embeddings` interface. - distance_strategy: The distance strategy to use. (default: COSINE) - search_type: The type of search to be performed, either - 'vector' or 'hybrid' - node_label: The label used for nodes in the Neo4j database. - (default: "Chunk") - embedding_node_property: The property name in Neo4j to store embeddings. - (default: "embedding") - text_node_property: The property name in Neo4j to store the text. - (default: "text") - retrieval_query: The Cypher query to be used for customizing retrieval. - If empty, a default query will be used. - index_type: The type of index to be used, either - 'NODE' or 'RELATIONSHIP' - pre_delete_collection: If True, will delete existing data if it exists. - (default: False). Useful for testing. - - Example: - .. 
code-block:: python - - from langchain_community.vectorstores.neo4j_vector import Neo4jVector - from langchain_community.embeddings.openai import OpenAIEmbeddings - - url="bolt://localhost:7687" - username="neo4j" - password="pleaseletmein" - embeddings = OpenAIEmbeddings() - vectorestore = Neo4jVector.from_documents( - embedding=embeddings, - documents=docs, - url=url - username=username, - password=password, - ) - - - """ - - def __init__( - self, - embedding: Embeddings, - *, - search_type: SearchType = SearchType.VECTOR, - username: Optional[str] = None, - password: Optional[str] = None, - url: Optional[str] = None, - keyword_index_name: Optional[str] = "keyword", - database: Optional[str] = None, - index_name: str = "vector", - node_label: str = "Chunk", - embedding_node_property: str = "embedding", - text_node_property: str = "text", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - logger: Optional[logging.Logger] = None, - pre_delete_collection: bool = False, - retrieval_query: str = "", - relevance_score_fn: Optional[Callable[[float], float]] = None, - index_type: IndexType = DEFAULT_INDEX_TYPE, - graph: Optional[Neo4jGraph] = None, - ) -> None: - try: - import neo4j - except ImportError: - raise ImportError( - "Could not import neo4j python package. " - "Please install it with `pip install neo4j`." - ) - - # Allow only cosine and euclidean distance strategies - if distance_strategy not in [ - DistanceStrategy.EUCLIDEAN_DISTANCE, - DistanceStrategy.COSINE, - ]: - raise ValueError( - "distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'" - ) - - # Graph object takes precedent over env or input params - if graph: - self._driver = graph._driver - self._database = graph._database - else: - # Handle if the credentials are environment variables - # Support URL for backwards compatibility - if not url: - url = os.environ.get("NEO4J_URL") - - url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI") - username = get_from_dict_or_env( - {"username": username}, "username", "NEO4J_USERNAME" - ) - password = get_from_dict_or_env( - {"password": password}, "password", "NEO4J_PASSWORD" - ) - database = get_from_dict_or_env( - {"database": database}, "database", "NEO4J_DATABASE", "neo4j" - ) - - self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) - self._database = database - # Verify connection - try: - self._driver.verify_connectivity() - except neo4j.exceptions.ServiceUnavailable: - raise ValueError( - "Could not connect to Neo4j database. " - "Please ensure that the url is correct" - ) - except neo4j.exceptions.AuthError: - raise ValueError( - "Could not connect to Neo4j database. 
" - "Please ensure that the username and password are correct" - ) - - self.schema = "" - # Verify if the version support vector index - self._is_enterprise = False - self.verify_version() - - # Verify that required values are not null - check_if_not_null( - [ - "index_name", - "node_label", - "embedding_node_property", - "text_node_property", - ], - [index_name, node_label, embedding_node_property, text_node_property], - ) - - self.embedding = embedding - self._distance_strategy = distance_strategy - self.index_name = index_name - self.keyword_index_name = keyword_index_name - self.node_label = node_label - self.embedding_node_property = embedding_node_property - self.text_node_property = text_node_property - self.logger = logger or logging.getLogger(__name__) - self.override_relevance_score_fn = relevance_score_fn - self.retrieval_query = retrieval_query - self.search_type = search_type - self._index_type = index_type - # Calculate embedding dimension - self.embedding_dimension = len(embedding.embed_query("foo")) - - # Delete existing data if flagged - if pre_delete_collection: - from neo4j.exceptions import DatabaseError - - self.query( - f"MATCH (n:`{self.node_label}`) " - "CALL (n) { DETACH DELETE n } " - "IN TRANSACTIONS OF 10000 ROWS;" - ) - # Delete index - try: - self.query(f"DROP INDEX {self.index_name}") - except DatabaseError: # Index didn't exist yet - pass - - def query( - self, - query: str, - *, - params: Optional[dict] = None, - ) -> List[Dict[str, Any]]: - """Query Neo4j database with retries and exponential backoff. - - Args: - query (str): The Cypher query to execute. - params (dict, optional): Dictionary of query parameters. Defaults to {}. - - Returns: - List[Dict[str, Any]]: List of dictionaries containing the query results. - """ - from neo4j import Query - from neo4j.exceptions import Neo4jError - - params = params or {} - try: - data, _, _ = self._driver.execute_query( - query, database_=self._database, parameters_=params - ) - return [r.data() for r in data] - except Neo4jError as e: - if not ( - ( - ( # isCallInTransactionError - e.code == "Neo.DatabaseError.Statement.ExecutionFailed" - or e.code - == "Neo.DatabaseError.Transaction.TransactionStartFailed" - ) - and "in an implicit transaction" in e.message - ) - or ( # isPeriodicCommitError - e.code == "Neo.ClientError.Statement.SemanticError" - and ( - "in an open transaction is not possible" in e.message - or "tried to execute in an explicit transaction" in e.message - ) - ) - ): - raise - # Fallback to allow implicit transactions - with self._driver.session(database=self._database) as session: - data = session.run(Query(text=query), params) - return [r.data() for r in data] - - def verify_version(self) -> None: - """ - Check if the connected Neo4j database version supports vector indexing. - - Queries the Neo4j database to retrieve its version and compares it - against a target version (5.11.0) that is known to support vector - indexing. Raises a ValueError if the connected Neo4j version is - not supported. 
- """ - db_data = self.query("CALL dbms.components()") - version = db_data[0]["versions"][0] - if "aura" in version: - version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,) - else: - version_tuple = tuple(map(int, version.split("."))) - - target_version = (5, 11, 0) - - if version_tuple < target_version: - raise ValueError( - "Version index is only supported in Neo4j version 5.11 or greater" - ) - - # Flag for metadata filtering - metadata_target_version = (5, 18, 0) - if version_tuple < metadata_target_version: - self.support_metadata_filter = False - else: - self.support_metadata_filter = True - # Flag for enterprise - self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False - - def retrieve_existing_index(self) -> Tuple[Optional[int], Optional[str]]: - """ - Check if the vector index exists in the Neo4j database - and returns its embedding dimension. - - This method queries the Neo4j database for existing indexes - and attempts to retrieve the dimension of the vector index - with the specified name. If the index exists, its dimension is returned. - If the index doesn't exist, `None` is returned. - - Returns: - int or None: The embedding dimension of the existing index if found. - """ - - index_information = self.query( - "SHOW INDEXES YIELD name, type, entityType, labelsOrTypes, " - "properties, options WHERE type = 'VECTOR' AND (name = $index_name " - "OR (labelsOrTypes[0] = $node_label AND " - "properties[0] = $embedding_node_property)) " - "RETURN name, entityType, labelsOrTypes, properties, options ", - params={ - "index_name": self.index_name, - "node_label": self.node_label, - "embedding_node_property": self.embedding_node_property, - }, - ) - # sort by index_name - index_information = sort_by_index_name(index_information, self.index_name) - try: - self.index_name = index_information[0]["name"] - self.node_label = index_information[0]["labelsOrTypes"][0] - self.embedding_node_property = index_information[0]["properties"][0] - self._index_type = index_information[0]["entityType"] - embedding_dimension = None - index_config = index_information[0]["options"]["indexConfig"] - if "vector.dimensions" in index_config: - embedding_dimension = index_config["vector.dimensions"] - - return embedding_dimension, index_information[0]["entityType"] - except IndexError: - return None, None - - def retrieve_existing_fts_index( - self, text_node_properties: List[str] = [] - ) -> Optional[str]: - """ - Check if the fulltext index exists in the Neo4j database - - This method queries the Neo4j database for existing fts indexes - with the specified name. 
- - Returns: - (Tuple): keyword index information - """ - - index_information = self.query( - "SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options " - "WHERE type = 'FULLTEXT' AND (name = $keyword_index_name " - "OR (labelsOrTypes = [$node_label] AND " - "properties = $text_node_property)) " - "RETURN name, labelsOrTypes, properties, options ", - params={ - "keyword_index_name": self.keyword_index_name, - "node_label": self.node_label, - "text_node_property": text_node_properties or [self.text_node_property], - }, - ) - # sort by index_name - index_information = sort_by_index_name(index_information, self.index_name) - try: - self.keyword_index_name = index_information[0]["name"] - self.text_node_property = index_information[0]["properties"][0] - node_label = index_information[0]["labelsOrTypes"][0] - return node_label - except IndexError: - return None - - def create_new_index(self) -> None: - """ - This method constructs a Cypher query and executes it - to create a new vector index in Neo4j. - """ - index_query = ( - f"CREATE VECTOR INDEX {self.index_name} IF NOT EXISTS " - f"FOR (m:`{self.node_label}`) ON m.`{self.embedding_node_property}` " - "OPTIONS { indexConfig: { " - "`vector.dimensions`: toInteger($embedding_dimension), " - "`vector.similarity_function`: $similarity_metric }}" - ) - - parameters = { - "embedding_dimension": self.embedding_dimension, - "similarity_metric": DISTANCE_MAPPING[self._distance_strategy], - } - self.query(index_query, params=parameters) - - def create_new_keyword_index(self, text_node_properties: List[str] = []) -> None: - """ - This method constructs a Cypher query and executes it - to create a new full text index in Neo4j. - """ - node_props = text_node_properties or [self.text_node_property] - fts_index_query = ( - f"CREATE FULLTEXT INDEX {self.keyword_index_name} " - f"FOR (n:`{self.node_label}`) ON EACH " - f"[{', '.join(['n.`' + el + '`' for el in node_props])}]" - ) - self.query(fts_index_query) - - @property - def embeddings(self) -> Embeddings: - return self.embedding - - @classmethod - def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - create_id_index: bool = True, - search_type: SearchType = SearchType.VECTOR, - **kwargs: Any, - ) -> Neo4jVector: - if ids is None: - ids = [md5(text.encode("utf-8")).hexdigest() for text in texts] - - if not metadatas: - metadatas = [{} for _ in texts] - - store = cls( - embedding=embedding, - search_type=search_type, - **kwargs, - ) - # Check if the vector index already exists - embedding_dimension, index_type = store.retrieve_existing_index() - - # Raise error if relationship index type - if index_type == "RELATIONSHIP": - raise ValueError( - "Data ingestion is not supported with relationship vector index." - ) - - # If the vector index doesn't exist yet - if not index_type: - store.create_new_index() - # If the index already exists, check if embedding dimensions match - elif ( - embedding_dimension and not store.embedding_dimension == embedding_dimension - ): - raise ValueError( - f"Index with name {store.index_name} already exists." 
- "The provided embedding function and vector index " - "dimensions do not match.\n" - f"Embedding function dimension: {store.embedding_dimension}\n" - f"Vector index dimension: {embedding_dimension}" - ) - - if search_type == SearchType.HYBRID: - fts_node_label = store.retrieve_existing_fts_index() - # If the FTS index doesn't exist yet - if not fts_node_label: - store.create_new_keyword_index() - else: # Validate that FTS and Vector index use the same information - if not fts_node_label == store.node_label: - raise ValueError( - "Vector and keyword index don't index the same node label" - ) - - # Create unique constraint for faster import - if create_id_index: - store.query( - "CREATE CONSTRAINT IF NOT EXISTS " - f"FOR (n:`{store.node_label}`) REQUIRE n.id IS UNIQUE;" - ) - - store.add_embeddings( - texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - ) - - return store - - def add_embeddings( - self, - texts: Iterable[str], - embeddings: List[List[float]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> List[str]: - """Add embeddings to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - embeddings: List of list of embedding vectors. - metadatas: List of metadatas associated with the texts. - kwargs: vectorstore specific parameters - """ - if ids is None: - ids = [md5(text.encode("utf-8")).hexdigest() for text in texts] - - if not metadatas: - metadatas = [{} for _ in texts] - - import_query = ( - "UNWIND $data AS row " - "CALL (row) { WITH row " - f"MERGE (c:`{self.node_label}` {{id: row.id}}) " - "WITH c, row " - f"CALL db.create.setNodeVectorProperty(c, " - f"'{self.embedding_node_property}', row.embedding) " - f"SET c.`{self.text_node_property}` = row.text " - "SET c += row.metadata " - "} IN TRANSACTIONS OF 1000 ROWS " - ) - - parameters = { - "data": [ - {"text": text, "metadata": metadata, "embedding": embedding, "id": id} - for text, metadata, embedding, id in zip( - texts, metadatas, embeddings, ids - ) - ] - } - - self.query(import_query, params=parameters) - - return ids - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - kwargs: vectorstore specific parameters - - Returns: - List of ids from adding the texts into the vectorstore. - """ - embeddings = self.embedding.embed_documents(list(texts)) - return self.add_embeddings( - texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - ) - - def similarity_search( - self, - query: str, - k: int = 4, - params: Dict[str, Any] = {}, - filter: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Document]: - """Run similarity search with Neo4jVector. - - Args: - query (str): Query text to search for. - k (int): Number of results to return. Defaults to 4. - params (Dict[str, Any]): The search params for the index type. - Defaults to empty dict. - filter (Optional[Dict[str, Any]]): Dictionary of argument(s) to - filter on metadata. - Defaults to None. - - Returns: - List of Documents most similar to the query. 
- """ - embedding = self.embedding.embed_query(text=query) - return self.similarity_search_by_vector( - embedding=embedding, - k=k, - query=query, - params=params, - filter=filter, - **kwargs, - ) - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - params: Dict[str, Any] = {}, - filter: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - params (Dict[str, Any]): The search params for the index type. - Defaults to empty dict. - filter (Optional[Dict[str, Any]]): Dictionary of argument(s) to - filter on metadata. - Defaults to None. - - Returns: - List of Documents most similar to the query and score for each - """ - embedding = self.embedding.embed_query(query) - docs = self.similarity_search_with_score_by_vector( - embedding=embedding, - k=k, - query=query, - params=params, - filter=filter, - **kwargs, - ) - return docs - - def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, Any]] = None, - params: Dict[str, Any] = {}, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """ - Perform a similarity search in the Neo4j database using a - given vector and return the top k similar documents with their scores. - - This method uses a Cypher query to find the top k documents that - are most similar to a given embedding. The similarity is measured - using a vector index in the Neo4j database. The results are returned - as a list of tuples, each containing a Document object and - its similarity score. - - Args: - embedding (List[float]): The embedding vector to compare against. - k (int, optional): The number of top similar documents to retrieve. - filter (Optional[Dict[str, Any]]): Dictionary of argument(s) to - filter on metadata. - Defaults to None. - params (Dict[str, Any]): The search params for the index type. - Defaults to empty dict. - - Returns: - List[Tuple[Document, float]]: A list of tuples, each containing - a Document object and its similarity score. 
- """ - if filter: - # Verify that 5.18 or later is used - if not self.support_metadata_filter: - raise ValueError( - "Metadata filtering is only supported in " - "Neo4j version 5.18 or greater" - ) - # Metadata filtering and hybrid doesn't work - if self.search_type == SearchType.HYBRID: - raise ValueError( - "Metadata filtering can't be use in combination with " - "a hybrid search approach" - ) - parallel_query = ( - "CYPHER runtime = parallel parallelRuntimeSupport=all " - if self._is_enterprise - else "" - ) - base_index_query = parallel_query + ( - f"MATCH (n:`{self.node_label}`) WHERE " - f"n.`{self.embedding_node_property}` IS NOT NULL AND " - f"size(n.`{self.embedding_node_property}`) = " - f"toInteger({self.embedding_dimension}) AND " - ) - base_cosine_query = ( - " WITH n as node, vector.similarity.cosine(" - f"n.`{self.embedding_node_property}`, " - "$embedding) AS score ORDER BY score DESC LIMIT toInteger($k) " - ) - filter_snippets, filter_params = construct_metadata_filter(filter) - index_query = base_index_query + filter_snippets + base_cosine_query - - else: - index_query = _get_search_index_query(self.search_type, self._index_type) - filter_params = {} - - if self._index_type == IndexType.RELATIONSHIP: - if kwargs.get("return_embeddings"): - default_retrieval = ( - f"RETURN relationship.`{self.text_node_property}` AS text, score, " - f"relationship {{.*, `{self.text_node_property}`: Null, " - f"`{self.embedding_node_property}`: Null, id: Null, " - f"_embedding_: relationship.`{self.embedding_node_property}`}} " - "AS metadata" - ) - else: - default_retrieval = ( - f"RETURN relationship.`{self.text_node_property}` AS text, score, " - f"relationship {{.*, `{self.text_node_property}`: Null, " - f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata" - ) - - else: - if kwargs.get("return_embeddings"): - default_retrieval = ( - f"RETURN node.`{self.text_node_property}` AS text, score, " - f"node {{.*, `{self.text_node_property}`: Null, " - f"`{self.embedding_node_property}`: Null, id: Null, " - f"_embedding_: node.`{self.embedding_node_property}`}} AS metadata" - ) - else: - default_retrieval = ( - f"RETURN node.`{self.text_node_property}` AS text, score, " - f"node {{.*, `{self.text_node_property}`: Null, " - f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata" - ) - - retrieval_query = ( - self.retrieval_query if self.retrieval_query else default_retrieval - ) - - read_query = index_query + retrieval_query - parameters = { - "index": self.index_name, - "k": k, - "embedding": embedding, - "keyword_index": self.keyword_index_name, - "query": remove_lucene_chars(kwargs["query"]), - **params, - **filter_params, - } - - results = self.query(read_query, params=parameters) - - if any(result["text"] is None for result in results): - if not self.retrieval_query: - raise ValueError( - f"Make sure that none of the `{self.text_node_property}` " - f"properties on nodes with label `{self.node_label}` " - "are missing or empty" - ) - else: - raise ValueError( - "Inspect the `retrieval_query` and ensure it doesn't " - "return None for the `text` column" - ) - if kwargs.get("return_embeddings") and any( - result["metadata"]["_embedding_"] is None for result in results - ): - if not self.retrieval_query: - raise ValueError( - f"Make sure that none of the `{self.embedding_node_property}` " - f"properties on nodes with label `{self.node_label}` " - "are missing or empty" - ) - else: - raise ValueError( - "Inspect the `retrieval_query` and ensure it doesn't " - "return 
None for the `_embedding_` metadata column" - ) - - docs = [ - ( - Document( - page_content=dict_to_yaml_str(result["text"]) - if isinstance(result["text"], dict) - else result["text"], - metadata={ - k: v for k, v in result["metadata"].items() if v is not None - }, - ), - result["score"], - ) - for result in results - ] - return docs - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[Dict[str, Any]] = None, - params: Dict[str, Any] = {}, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, Any]]): Dictionary of argument(s) to - filter on metadata. - Defaults to None. - params (Dict[str, Any]): The search params for the index type. - Defaults to empty dict. - - Returns: - List of Documents most similar to the query vector. - """ - docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter, params=params, **kwargs - ) - return [doc for doc, _ in docs_and_scores] - - @classmethod - def from_texts( - cls: Type[Neo4jVector], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> Neo4jVector: - """ - Return Neo4jVector initialized from texts and embeddings. - Neo4j credentials are required in the form of `url`, `username`, - and `password` and optional `database` parameters. - """ - embeddings = embedding.embed_documents(list(texts)) - - return cls.__from( - texts, - embeddings, - embedding, - metadatas=metadatas, - ids=ids, - distance_strategy=distance_strategy, - **kwargs, - ) - - @classmethod - def from_embeddings( - cls, - text_embeddings: List[Tuple[str, List[float]]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - pre_delete_collection: bool = False, - **kwargs: Any, - ) -> Neo4jVector: - """Construct Neo4jVector wrapper from raw documents and pre- - generated embeddings. - - Return Neo4jVector initialized from documents and embeddings. - Neo4j credentials are required in the form of `url`, `username`, - and `password` and optional `database` parameters. - - Example: - .. code-block:: python - - from langchain_community.vectorstores.neo4j_vector import Neo4jVector - from langchain_community.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - text_embeddings = embeddings.embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - vectorstore = Neo4jVector.from_embeddings( - text_embedding_pairs, embeddings) - """ - texts = [t[0] for t in text_embeddings] - embeddings = [t[1] for t in text_embeddings] - - return cls.__from( - texts, - embeddings, - embedding, - metadatas=metadatas, - ids=ids, - distance_strategy=distance_strategy, - pre_delete_collection=pre_delete_collection, - **kwargs, - ) - - @classmethod - def from_existing_index( - cls: Type[Neo4jVector], - embedding: Embeddings, - index_name: str, - search_type: SearchType = DEFAULT_SEARCH_TYPE, - keyword_index_name: Optional[str] = None, - **kwargs: Any, - ) -> Neo4jVector: - """ - Get instance of an existing Neo4j vector index. This method will - return the instance of the store without inserting any new - embeddings. 
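Continuing the earlier sketch, a hedged illustration of the hybrid variant of this constructor (index names and credentials are assumptions): a keyword_index_name must be supplied whenever search_type is HYBRID, which is exactly what the check below enforces.

from langchain_community.vectorstores.neo4j_vector import Neo4jVector, SearchType
from langchain_openai import OpenAIEmbeddings

hybrid_store = Neo4jVector.from_existing_index(
    embedding=OpenAIEmbeddings(),
    url="bolt://localhost:7687",
    username="neo4j",
    password="password",
    index_name="vector",
    keyword_index_name="keyword",  # mandatory for hybrid search
    search_type=SearchType.HYBRID,
)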
- Neo4j credentials are required in the form of `url`, `username`, - and `password` and optional `database` parameters along with - the `index_name` definition. - """ - - if search_type == SearchType.HYBRID and not keyword_index_name: - raise ValueError( - "keyword_index name has to be specified when using hybrid search option" - ) - - store = cls( - embedding=embedding, - index_name=index_name, - keyword_index_name=keyword_index_name, - search_type=search_type, - **kwargs, - ) - - embedding_dimension, index_type = store.retrieve_existing_index() - - # Raise error if relationship index type - if index_type == "RELATIONSHIP": - raise ValueError( - "Relationship vector index is not supported with " - "`from_existing_index` method. Please use the " - "`from_existing_relationship_index` method." - ) - - if not index_type: - raise ValueError( - "The specified vector index name does not exist. " - "Make sure to check if you spelled it correctly" - ) - - # Check if embedding function and vector index dimensions match - if embedding_dimension and not store.embedding_dimension == embedding_dimension: - raise ValueError( - "The provided embedding function and vector index " - "dimensions do not match.\n" - f"Embedding function dimension: {store.embedding_dimension}\n" - f"Vector index dimension: {embedding_dimension}" - ) - - if search_type == SearchType.HYBRID: - fts_node_label = store.retrieve_existing_fts_index() - # If the FTS index doesn't exist yet - if not fts_node_label: - raise ValueError( - "The specified keyword index name does not exist. " - "Make sure to check if you spelled it correctly" - ) - else: # Validate that FTS and Vector index use the same information - if not fts_node_label == store.node_label: - raise ValueError( - "Vector and keyword index don't index the same node label" - ) - - return store - - @classmethod - def from_existing_relationship_index( - cls: Type[Neo4jVector], - embedding: Embeddings, - index_name: str, - search_type: SearchType = DEFAULT_SEARCH_TYPE, - **kwargs: Any, - ) -> Neo4jVector: - """ - Get instance of an existing Neo4j relationship vector index. - This method will return the instance of the store without - inserting any new embeddings. - Neo4j credentials are required in the form of `url`, `username`, - and `password` and optional `database` parameters along with - the `index_name` definition. - """ - - if search_type == SearchType.HYBRID: - raise ValueError( - "Hybrid search is not supported in combination " - "with relationship vector index" - ) - - store = cls( - embedding=embedding, - index_name=index_name, - **kwargs, - ) - - embedding_dimension, index_type = store.retrieve_existing_index() - - if not index_type: - raise ValueError( - "The specified vector index name does not exist. " - "Make sure to check if you spelled it correctly" - ) - # Raise error if relationship index type - if index_type == "NODE": - raise ValueError( - "Node vector index is not supported with " - "`from_existing_relationship_index` method. Please use the " - "`from_existing_index` method." 
- ) - - # Check if embedding function and vector index dimensions match - if embedding_dimension and not store.embedding_dimension == embedding_dimension: - raise ValueError( - "The provided embedding function and vector index " - "dimensions do not match.\n" - f"Embedding function dimension: {store.embedding_dimension}\n" - f"Vector index dimension: {embedding_dimension}" - ) - - return store - - @classmethod - def from_documents( - cls: Type[Neo4jVector], - documents: List[Document], - embedding: Embeddings, - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> Neo4jVector: - """ - Return Neo4jVector initialized from documents and embeddings. - Neo4j credentials are required in the form of `url`, `username`, - and `password` and optional `database` parameters. - """ - - texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] - - return cls.from_texts( - texts=texts, - embedding=embedding, - distance_strategy=distance_strategy, - metadatas=metadatas, - ids=ids, - **kwargs, - ) - - @classmethod - def from_existing_graph( - cls: Type[Neo4jVector], - embedding: Embeddings, - node_label: str, - embedding_node_property: str, - text_node_properties: List[str], - *, - keyword_index_name: Optional[str] = "keyword", - index_name: str = "vector", - search_type: SearchType = DEFAULT_SEARCH_TYPE, - retrieval_query: str = "", - **kwargs: Any, - ) -> Neo4jVector: - """ - Initialize and return a Neo4jVector instance from an existing graph. - - This method initializes a Neo4jVector instance using the provided - parameters and the existing graph. It validates the existence of - the indices and creates new ones if they don't exist. - - Returns: - Neo4jVector: An instance of Neo4jVector initialized with the provided parameters - and existing graph. - - Example: - >>> neo4j_vector = Neo4jVector.from_existing_graph( - ... embedding=my_embedding, - ... node_label="Document", - ... embedding_node_property="embedding", - ... text_node_properties=["title", "content"] - ... ) - - Note: - Neo4j credentials are required in the form of `url`, `username`, and `password`, - and optional `database` parameters passed as additional keyword arguments. - """ - # Validate the list is not empty - if not text_node_properties: - raise ValueError( - "Parameter `text_node_properties` must not be an empty list" - ) - # Prefer retrieval query from params, otherwise construct it - if not retrieval_query: - retrieval_query = ( - f"RETURN reduce(str='', k IN {text_node_properties} |" - " str + '\\n' + k + ': ' + coalesce(node[k], '')) AS text, " - "node {.*, `" - + embedding_node_property - + "`: Null, id: Null, " - + ", ".join([f"`{prop}`: Null" for prop in text_node_properties]) - + "} AS metadata, score" - ) - store = cls( - embedding=embedding, - index_name=index_name, - keyword_index_name=keyword_index_name, - search_type=search_type, - retrieval_query=retrieval_query, - node_label=node_label, - embedding_node_property=embedding_node_property, - **kwargs, - ) - - # Check if the vector index already exists - embedding_dimension, index_type = store.retrieve_existing_index() - - # Raise error if relationship index type - if index_type == "RELATIONSHIP": - raise ValueError( - "`from_existing_graph` method does not support " - " existing relationship vector index. 
" - "Please use `from_existing_relationship_index` method" - ) - - # If the vector index doesn't exist yet - if not index_type: - store.create_new_index() - # If the index already exists, check if embedding dimensions match - elif ( - embedding_dimension and not store.embedding_dimension == embedding_dimension - ): - raise ValueError( - f"Index with name {store.index_name} already exists." - "The provided embedding function and vector index " - "dimensions do not match.\n" - f"Embedding function dimension: {store.embedding_dimension}\n" - f"Vector index dimension: {embedding_dimension}" - ) - # FTS index for Hybrid search - if search_type == SearchType.HYBRID: - fts_node_label = store.retrieve_existing_fts_index(text_node_properties) - # If the FTS index doesn't exist yet - if not fts_node_label: - store.create_new_keyword_index(text_node_properties) - else: # Validate that FTS and Vector index use the same information - if not fts_node_label == store.node_label: - raise ValueError( - "Vector and keyword index don't index the same node label" - ) - - # Populate embeddings - while True: - fetch_query = ( - f"MATCH (n:`{node_label}`) " - f"WHERE n.{embedding_node_property} IS null " - "AND any(k in $props WHERE n[k] IS NOT null) " - f"RETURN elementId(n) AS id, reduce(str=''," - "k IN $props | str + '\\n' + k + ':' + coalesce(n[k], '')) AS text " - "LIMIT 1000" - ) - data = store.query(fetch_query, params={"props": text_node_properties}) - if not data: - break - text_embeddings = embedding.embed_documents([el["text"] for el in data]) - - params = { - "data": [ - {"id": el["id"], "embedding": embedding} - for el, embedding in zip(data, text_embeddings) - ] - } - - store.query( - "UNWIND $data AS row " - f"MATCH (n:`{node_label}`) " - "WHERE elementId(n) = row.id " - f"CALL db.create.setNodeVectorProperty(n, " - f"'{embedding_node_property}', row.embedding) " - "RETURN count(*)", - params=params, - ) - # If embedding calculation should be stopped - if len(data) < 1000: - break - return store - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: search query text. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filter on metadata properties, e.g. - { - "str_property": "foo", - "int_property": 123 - } - Returns: - List of Documents selected by maximal marginal relevance. 
- """ - # Embed the query - query_embedding = self.embedding.embed_query(query) - - # Fetch the initial documents - got_docs = self.similarity_search_with_score_by_vector( - embedding=query_embedding, - query=query, - k=fetch_k, - return_embeddings=True, - filter=filter, - **kwargs, - ) - - # Get the embeddings for the fetched documents - got_embeddings = [doc.metadata["_embedding_"] for doc, _ in got_docs] - - # Select documents using maximal marginal relevance - selected_indices = maximal_marginal_relevance( - np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k - ) - selected_docs = [got_docs[i][0] for i in selected_indices] - - # Remove embedding values from metadata - for doc in selected_docs: - del doc.metadata["_embedding_"] - - return selected_docs - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - """ - if self.override_relevance_score_fn is not None: - return self.override_relevance_score_fn - - # Default strategy is to rely on distance strategy provided - # in vectorstore constructor - if self._distance_strategy == DistanceStrategy.COSINE: - return lambda x: x - elif self._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - return lambda x: x - else: - raise ValueError( - "No supported normalization function" - f" for distance_strategy of {self._distance_strategy}." - "Consider providing relevance_score_fn to PGVector constructor." - ) diff --git a/libs/community/langchain_community/vectorstores/oraclevs.py b/libs/community/langchain_community/vectorstores/oraclevs.py deleted file mode 100644 index e0b1afca0..000000000 --- a/libs/community/langchain_community/vectorstores/oraclevs.py +++ /dev/null @@ -1,1090 +0,0 @@ -from __future__ import annotations - -import array -import functools -import hashlib -import json -import logging -import os -import uuid -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from numpy.typing import NDArray - -if TYPE_CHECKING: - from oracledb import Connection - -import numpy as np -from langchain_core._api import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore - -from langchain_community.vectorstores.utils import ( - DistanceStrategy, - maximal_marginal_relevance, -) - -logger = logging.getLogger(__name__) -log_level = os.getenv("LOG_LEVEL", "ERROR").upper() -logging.basicConfig( - level=getattr(logging, log_level), - format="%(asctime)s - %(levelname)s - %(message)s", -) - - -# Define a type variable that can be any kind of function -T = TypeVar("T", bound=Callable[..., Any]) - - -def _get_connection(client: Any) -> Connection | None: - # Dynamically import oracledb and the required classes - try: - import oracledb - except ImportError as e: - raise ImportError( - "Unable to import oracledb, please install with `pip install -U oracledb`." 
- ) from e - - # check if ConnectionPool exists - connection_pool_class = getattr(oracledb, "ConnectionPool", None) - - if isinstance(client, oracledb.Connection): - return client - elif connection_pool_class and isinstance(client, connection_pool_class): - return client.acquire() - else: - valid_types = "oracledb.Connection" - if connection_pool_class: - valid_types += " or oracledb.ConnectionPool" - raise TypeError( - f"Expected client of type {valid_types}, got {type(client).__name__}" - ) - - -def _handle_exceptions(func: T) -> T: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - try: - return func(*args, **kwargs) - except RuntimeError as db_err: - # Handle a known type of error (e.g., DB-related) specifically - logger.exception("DB-related error occurred.") - raise RuntimeError( - "Failed due to a DB issue: {}".format(db_err) - ) from db_err - except ValueError as val_err: - # Handle another known type of error specifically - logger.exception("Validation error.") - raise ValueError("Validation failed: {}".format(val_err)) from val_err - except Exception as e: - # Generic handler for all other exceptions - logger.exception("An unexpected error occurred: {}".format(e)) - raise RuntimeError("Unexpected error: {}".format(e)) from e - - return cast(T, wrapper) - - -def _table_exists(connection: Connection, table_name: str) -> bool: - try: - import oracledb - except ImportError as e: - raise ImportError( - "Unable to import oracledb, please install with `pip install -U oracledb`." - ) from e - - try: - with connection.cursor() as cursor: - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - return True - except oracledb.DatabaseError as ex: - err_obj = ex.args - if err_obj[0].code == 942: - return False - raise - - -def _compare_version(version: str, target_version: str) -> bool: - # Split both version strings into parts - version_parts = [int(part) for part in version.split(".")] - target_parts = [int(part) for part in target_version.split(".")] - - # Compare each part - for v, t in zip(version_parts, target_parts): - if v < t: - return True # Current version is less - elif v > t: - return False # Current version is greater - - # If all parts equal so far, check if version has fewer parts than target_version - return len(version_parts) < len(target_parts) - - -@_handle_exceptions -def _index_exists(connection: Connection, index_name: str) -> bool: - # Check if the index exists - query = """ - SELECT index_name - FROM all_indexes - WHERE upper(index_name) = upper(:idx_name) - """ - - with connection.cursor() as cursor: - # Execute the query - cursor.execute(query, idx_name=index_name.upper()) - result = cursor.fetchone() - - # Check if the index exists - return result is not None - - -def _get_distance_function(distance_strategy: DistanceStrategy) -> str: - # Dictionary to map distance strategies to their corresponding function - # names - distance_strategy2function = { - DistanceStrategy.EUCLIDEAN_DISTANCE: "EUCLIDEAN", - DistanceStrategy.DOT_PRODUCT: "DOT", - DistanceStrategy.COSINE: "COSINE", - } - - # Attempt to return the corresponding distance function - if distance_strategy in distance_strategy2function: - return distance_strategy2function[distance_strategy] - - # If it's an unsupported distance strategy, raise an error - raise ValueError(f"Unsupported distance strategy: {distance_strategy}") - - -def _get_index_name(base_name: str) -> str: - unique_id = str(uuid.uuid4()).replace("-", "") - return f"{base_name}_{unique_id}" - - -@_handle_exceptions -def 
_create_table(connection: Connection, table_name: str, embedding_dim: int) -> None: - cols_dict = { - "id": "RAW(16) DEFAULT SYS_GUID() PRIMARY KEY", - "text": "CLOB", - "metadata": "JSON", - "embedding": f"vector({embedding_dim}, FLOAT32)", - } - - if not _table_exists(connection, table_name): - with connection.cursor() as cursor: - ddl_body = ", ".join( - f"{col_name} {col_type}" for col_name, col_type in cols_dict.items() - ) - ddl = f"CREATE TABLE {table_name} ({ddl_body})" - cursor.execute(ddl) - logger.info("Table created successfully...") - else: - logger.info("Table already exists...") - - -@_handle_exceptions -def create_index( - client: Any, - vector_store: OracleVS, - params: Optional[dict[str, Any]] = None, -) -> None: - connection = _get_connection(client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - if params: - if params["idx_type"] == "HNSW": - _create_hnsw_index( - connection, - vector_store.table_name, - vector_store.distance_strategy, - params, - ) - elif params["idx_type"] == "IVF": - _create_ivf_index( - connection, - vector_store.table_name, - vector_store.distance_strategy, - params, - ) - else: - _create_hnsw_index( - connection, - vector_store.table_name, - vector_store.distance_strategy, - params, - ) - else: - _create_hnsw_index( - connection, vector_store.table_name, vector_store.distance_strategy, params - ) - return - - -@_handle_exceptions -def _create_hnsw_index( - connection: Connection, - table_name: str, - distance_strategy: DistanceStrategy, - params: Optional[dict[str, Any]] = None, -) -> None: - defaults = { - "idx_name": "HNSW", - "idx_type": "HNSW", - "neighbors": 32, - "efConstruction": 200, - "accuracy": 90, - "parallel": 8, - } - - if params: - config = params.copy() - # Ensure compulsory parts are included - for compulsory_key in ["idx_name", "parallel"]: - if compulsory_key not in config: - if compulsory_key == "idx_name": - config[compulsory_key] = _get_index_name( - str(defaults[compulsory_key]) - ) - else: - config[compulsory_key] = defaults[compulsory_key] - - # Validate keys in config against defaults - for key in config: - if key not in defaults: - raise ValueError(f"Invalid parameter: {key}") - else: - config = defaults - - # Base SQL statement - idx_name = config["idx_name"] - base_sql = ( - f"create vector index {idx_name} on {table_name}(embedding) " - f"ORGANIZATION INMEMORY NEIGHBOR GRAPH" - ) - - # Optional parts depending on parameters - accuracy_part = " WITH TARGET ACCURACY {accuracy}" if ("accuracy" in config) else "" - distance_part = f" DISTANCE {_get_distance_function(distance_strategy)}" - - parameters_part = "" - if "neighbors" in config and "efConstruction" in config: - parameters_part = ( - " parameters (type {idx_type}, neighbors {" - "neighbors}, efConstruction {efConstruction})" - ) - elif "neighbors" in config and "efConstruction" not in config: - config["efConstruction"] = defaults["efConstruction"] - parameters_part = ( - " parameters (type {idx_type}, neighbors {" - "neighbors}, efConstruction {efConstruction})" - ) - elif "neighbors" not in config and "efConstruction" in config: - config["neighbors"] = defaults["neighbors"] - parameters_part = ( - " parameters (type {idx_type}, neighbors {" - "neighbors}, efConstruction {efConstruction})" - ) - - # Always included part for parallel - parallel_part = " parallel {parallel}" - - # Combine all parts - ddl_assembly = ( - base_sql + accuracy_part + distance_part + parameters_part + parallel_part - ) - # Format the SQL with values from 
the params dictionary - ddl = ddl_assembly.format(**config) - - # Check if the index exists - if not _index_exists(connection, config["idx_name"]): - with connection.cursor() as cursor: - cursor.execute(ddl) - logger.info("Index created successfully...") - else: - logger.info("Index already exists...") - - -@_handle_exceptions -def _create_ivf_index( - connection: Connection, - table_name: str, - distance_strategy: DistanceStrategy, - params: Optional[dict[str, Any]] = None, -) -> None: - # Default configuration - defaults = { - "idx_name": "IVF", - "idx_type": "IVF", - "neighbor_part": 32, - "accuracy": 90, - "parallel": 8, - } - - if params: - config = params.copy() - # Ensure compulsory parts are included - for compulsory_key in ["idx_name", "parallel"]: - if compulsory_key not in config: - if compulsory_key == "idx_name": - config[compulsory_key] = _get_index_name( - str(defaults[compulsory_key]) - ) - else: - config[compulsory_key] = defaults[compulsory_key] - - # Validate keys in config against defaults - for key in config: - if key not in defaults: - raise ValueError(f"Invalid parameter: {key}") - else: - config = defaults - - # Base SQL statement - idx_name = config["idx_name"] - base_sql = ( - f"CREATE VECTOR INDEX {idx_name} ON {table_name}(embedding) " - f"ORGANIZATION NEIGHBOR PARTITIONS" - ) - - # Optional parts depending on parameters - accuracy_part = " WITH TARGET ACCURACY {accuracy}" if ("accuracy" in config) else "" - distance_part = f" DISTANCE {_get_distance_function(distance_strategy)}" - - parameters_part = "" - if "idx_type" in config and "neighbor_part" in config: - parameters_part = ( - f" PARAMETERS (type {config['idx_type']}, neighbor" - f" partitions {config['neighbor_part']})" - ) - - # Always included part for parallel - parallel_part = f" PARALLEL {config['parallel']}" - - # Combine all parts - ddl_assembly = ( - base_sql + accuracy_part + distance_part + parameters_part + parallel_part - ) - # Format the SQL with values from the params dictionary - ddl = ddl_assembly.format(**config) - - # Check if the index exists - if not _index_exists(connection, config["idx_name"]): - with connection.cursor() as cursor: - cursor.execute(ddl) - logger.info("Index created successfully...") - else: - logger.info("Index already exists...") - - -@_handle_exceptions -def drop_table_purge(client: Any, table_name: str) -> None: - """Drop a table and purge it from the database. - - Args: - client: The OracleDB connection object. - table_name: The name of the table to drop. - - Raises: - RuntimeError: If an error occurs while dropping the table. - """ - connection = _get_connection(client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - if _table_exists(connection, table_name): - with connection.cursor() as cursor: - ddl = f"DROP TABLE {table_name} PURGE" - cursor.execute(ddl) - logger.info("Table dropped successfully...") - else: - logger.info("Table not found...") - return - - -@_handle_exceptions -def drop_index_if_exists(client: Any, index_name: str) -> None: - """Drop an index if it exists. - - Args: - client: The OracleDB connection object. - index_name: The name of the index to drop. - - Raises: - RuntimeError: If an error occurs while dropping the index. 
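To round off the index-management helpers being removed here, a hedged end-to-end sketch of how this module's API was typically driven. Connection details, table and index names are invented, and note that the module itself is deleted by this patch, with langchain-oracledb named below as the replacement.

import oracledb
from langchain_community.vectorstores.oraclevs import (
    OracleVS,
    create_index,
    drop_index_if_exists,
    drop_table_purge,
)
from langchain_openai import OpenAIEmbeddings

# Assumed local Oracle instance and credentials.
connection = oracledb.connect(user="vector", password="secret", dsn="localhost/FREEPDB1")

vs = OracleVS(connection, OpenAIEmbeddings(), table_name="DOCS")
vs.add_texts(["Oracle 23ai ships a native VECTOR type."], metadatas=[{"source": "notes"}])

# Build an HNSW vector index; omitted parameters fall back to the defaults above.
create_index(connection, vs, params={"idx_name": "DOCS_HNSW", "idx_type": "HNSW"})

# ...run searches, then clean up:
drop_index_if_exists(connection, "DOCS_HNSW")
drop_table_purge(connection, "DOCS")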
- """ - connection = _get_connection(client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - if _index_exists(connection, index_name): - drop_query = f"DROP INDEX {index_name}" - with connection.cursor() as cursor: - cursor.execute(drop_query) - logger.info(f"Index {index_name} has been dropped.") - else: - logger.exception(f"Index {index_name} does not exist.") - return - - -@deprecated( - since="0.3.30", - removal="1.0", - message=( - "This class is deprecated and will be removed in a future release. " - "Instead, please use `OracleVS` from the " - "`langchain-oracledb` package. " - "For more information, refer to ." - ), - alternative="from langchain_oracledb.vectorstores import OracleVS;", - pending=False, -) -class OracleVS(VectorStore): - """`OracleVS` vector store. - - To use, you should have both: - - the ``oracledb`` python package installed - - a connection string associated with a OracleDBCluster having deployed an - Search index - - Example: - .. code-block:: python - - from langchain_classic.vectorstores import OracleVS - from langchain_classic.embeddings.openai import OpenAIEmbeddings - import oracledb - - with oracledb.connect(user = user, passwd = pwd, dsn = dsn) as - connection: - print ("Database version:", connection.version) - embeddings = OpenAIEmbeddings() - query = "" - vectors = OracleVS(connection, table_name, embeddings, query) - """ - - def __init__( - self, - client: Any, - embedding_function: Union[ - Callable[[str], List[float]], - Embeddings, - ], - table_name: str, - distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE, - query: Optional[str] = "What is a Oracle database", - params: Optional[Dict[str, Any]] = None, - ): - try: - import oracledb - except ImportError as e: - raise ImportError( - "Unable to import oracledb, please install with " - "`pip install -U oracledb`." - ) from e - - self.insert_mode = "array" - connection = _get_connection(client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - - if hasattr(connection, "thin") and connection.thin: - if oracledb.__version__ == "2.1.0": - raise Exception( - "Oracle DB python thin client driver version 2.1.0 not supported" - ) - elif _compare_version(oracledb.__version__, "2.2.0"): - self.insert_mode = "clob" - else: - self.insert_mode = "array" - else: - if (_compare_version(oracledb.__version__, "2.1.0")) and ( - not ( - _compare_version( - ".".join(map(str, oracledb.clientversion())), "23.4" - ) - ) - ): - raise Exception( - "Oracle DB python thick client driver version earlier than " - "2.1.0 not supported with client libraries greater than " - "equal to 23.4" - ) - - if _compare_version(".".join(map(str, oracledb.clientversion())), "23.4"): - self.insert_mode = "clob" - else: - self.insert_mode = "array" - - if _compare_version(oracledb.__version__, "2.1.0"): - self.insert_mode = "clob" - - try: - """Initialize with oracledb client.""" - self.client = client - """Initialize with necessary components.""" - if not isinstance(embedding_function, Embeddings): - logger.warning( - "`embedding_function` is expected to be an Embeddings " - "object, support " - "for passing in a function will soon be removed." 
- ) - self.embedding_function = embedding_function - self.query = query - embedding_dim = self.get_embedding_dimension() - - self.table_name = table_name - self.distance_strategy = distance_strategy - self.params = params - _create_table(connection, table_name, embedding_dim) - except oracledb.DatabaseError as db_err: - logger.exception(f"Database error occurred while create table: {db_err}") - raise RuntimeError( - "Failed to create table due to a database error." - ) from db_err - except ValueError as val_err: - logger.exception(f"Validation error: {val_err}") - raise RuntimeError( - "Failed to create table due to a validation error." - ) from val_err - except Exception as ex: - logger.exception("An unexpected error occurred while creating the index.") - raise RuntimeError( - "Failed to create table due to an unexpected error." - ) from ex - - @property - def embeddings(self) -> Optional[Embeddings]: - """ - A property that returns an Embeddings instance embedding_function - is an instance of Embeddings, otherwise returns None. - - Returns: - Optional[Embeddings]: The embedding function if it's an instance of - Embeddings, otherwise None. - """ - return ( - self.embedding_function - if isinstance(self.embedding_function, Embeddings) - else None - ) - - def get_embedding_dimension(self) -> int: - # Embed the single document by wrapping it in a list - embedded_document = self._embed_documents( - [self.query if self.query is not None else ""] - ) - - # Get the first (and only) embedding's dimension - return len(embedded_document[0]) - - def _embed_documents(self, texts: List[str]) -> List[List[float]]: - if isinstance(self.embedding_function, Embeddings): - return self.embedding_function.embed_documents(texts) - elif callable(self.embedding_function): - return [self.embedding_function(text) for text in texts] - else: - raise TypeError( - "The embedding_function is neither Embeddings nor callable." - ) - - def _embed_query(self, text: str) -> List[float]: - if isinstance(self.embedding_function, Embeddings): - return self.embedding_function.embed_query(text) - else: - return self.embedding_function(text) - - @_handle_exceptions - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[Dict[Any, Any]]] = None, - ids: Optional[List[str]] = None, - **kwargs: Any, - ) -> List[str]: - """Add more texts to the vectorstore index. - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: Optional list of ids for the texts that are being added to - the vector store. 
- kwargs: vectorstore specific parameters - """ - - texts = list(texts) - if ids: - # If ids are provided, hash them to maintain consistency - processed_ids = [ - hashlib.sha256(_id.encode()).hexdigest()[:16].upper() for _id in ids - ] - elif metadatas and all("id" in metadata for metadata in metadatas): - # If no ids are provided but metadatas with ids are, generate - # ids from metadatas - processed_ids = [ - hashlib.sha256(metadata["id"].encode()).hexdigest()[:16].upper() - for metadata in metadatas - ] - else: - # Generate new ids if none are provided - generated_ids = [ - str(uuid.uuid4()) for _ in texts - ] # uuid4 is more standard for random UUIDs - processed_ids = [ - hashlib.sha256(_id.encode()).hexdigest()[:16].upper() - for _id in generated_ids - ] - - embeddings = self._embed_documents(texts) - if not metadatas: - metadatas = [{} for _ in texts] - - docs: List[Tuple[Any, Any, Any, Any]] - if self.insert_mode == "clob": - docs = [ - (id_, json.dumps(embedding), json.dumps(metadata), text) - for id_, embedding, metadata, text in zip( - processed_ids, embeddings, metadatas, texts - ) - ] - else: - docs = [ - (id_, array.array("f", embedding), json.dumps(metadata), text) - for id_, embedding, metadata, text in zip( - processed_ids, embeddings, metadatas, texts - ) - ] - - connection = _get_connection(self.client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - with connection.cursor() as cursor: - cursor.executemany( - f"INSERT INTO {self.table_name} (id, embedding, metadata, " - f"text) VALUES (:1, :2, :3, :4)", - docs, - ) - connection.commit() - return processed_ids - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to query.""" - embedding: List[float] = [] - if isinstance(self.embedding_function, Embeddings): - embedding = self.embedding_function.embed_query(query) - documents = self.similarity_search_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs - ) - return documents - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Document]: - docs_and_scores = self.similarity_search_by_vector_with_relevance_scores( - embedding=embedding, k=k, filter=filter, **kwargs - ) - return [doc for doc, _ in docs_and_scores] - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to query.""" - embedding: List[float] = [] - if isinstance(self.embedding_function, Embeddings): - embedding = self.embedding_function.embed_query(query) - docs_and_scores = self.similarity_search_by_vector_with_relevance_scores( - embedding=embedding, k=k, filter=filter, **kwargs - ) - return docs_and_scores - - @_handle_exceptions - def _get_clob_value(self, result: Any) -> str: - try: - import oracledb - except ImportError as e: - raise ImportError( - "Unable to import oracledb, please install with " - "`pip install -U oracledb`." 
- ) from e - - clob_value = "" - if result: - if isinstance(result, oracledb.LOB): - raw_data = result.read() - if isinstance(raw_data, bytes): - clob_value = raw_data.decode( - "utf-8" - ) # Specify the correct encoding - else: - clob_value = raw_data - elif isinstance(result, str): - clob_value = result - else: - raise Exception("Unexpected type:", type(result)) - return clob_value - - @_handle_exceptions - def similarity_search_by_vector_with_relevance_scores( - self, - embedding: List[float], - k: int = 4, - filter: Optional[dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - docs_and_scores = [] - - embedding_arr: Any - if self.insert_mode == "clob": - embedding_arr = json.dumps(embedding) - else: - embedding_arr = array.array("f", embedding) - - query = f""" - SELECT id, - text, - metadata, - vector_distance(embedding, :embedding, - {_get_distance_function(self.distance_strategy)}) as distance - FROM {self.table_name} - ORDER BY distance - FETCH APPROX FIRST {k} ROWS ONLY - """ - - # Execute the query - connection = _get_connection(self.client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - with connection.cursor() as cursor: - cursor.execute(query, embedding=embedding_arr) - results = cursor.fetchall() - - # Filter results if filter is provided - for result in results: - metadata = dict(result[2]) if isinstance(result[2], dict) else {} - - # Apply filtering based on the 'filter' dictionary - if filter: - if all(metadata.get(key) in value for key, value in filter.items()): - doc = Document( - page_content=( - self._get_clob_value(result[1]) - if result[1] is not None - else "" - ), - metadata=metadata, - ) - distance = result[3] - docs_and_scores.append((doc, distance)) - else: - doc = Document( - page_content=( - self._get_clob_value(result[1]) - if result[1] is not None - else "" - ), - metadata=metadata, - ) - distance = result[3] - docs_and_scores.append((doc, distance)) - - return docs_and_scores - - @_handle_exceptions - def similarity_search_by_vector_returning_embeddings( - self, - embedding: List[float], - k: int, - filter: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float, NDArray[np.float32]]]: - embedding_arr: Any - if self.insert_mode == "clob": - embedding_arr = json.dumps(embedding) - else: - embedding_arr = array.array("f", embedding) - - documents = [] - - query = f""" - SELECT id, - text, - metadata, - vector_distance(embedding, :embedding, { - _get_distance_function(self.distance_strategy) - }) as distance, - embedding - FROM {self.table_name} - ORDER BY distance - FETCH APPROX FIRST {k} ROWS ONLY - """ - - # Execute the query - connection = _get_connection(self.client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - with connection.cursor() as cursor: - cursor.execute(query, embedding=embedding_arr) - results = cursor.fetchall() - - for result in results: - page_content_str = self._get_clob_value(result[1]) - metadata = result[2] if isinstance(result[2], dict) else {} - - # Apply filter if provided and matches; otherwise, add all - # documents - if not filter or all( - metadata.get(key) in value for key, value in filter.items() - ): - document = Document( - page_content=page_content_str, metadata=metadata - ) - distance = result[3] - - # Assuming result[4] is already in the correct format; - # adjust if necessary - current_embedding = ( - np.array(result[4], dtype=np.float32) - if result[4] - else np.empty(0, dtype=np.float32) - ) - - 
documents.append((document, distance, current_embedding)) - return documents - - @_handle_exceptions - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - *, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, - ) -> List[Tuple[Document, float]]: - """Return docs and their similarity scores selected using the - maximal marginal - relevance. - - Maximal marginal relevance optimizes for similarity to query AND - diversity - among selected documents. - - Args: - self: An instance of the class - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch before filtering to - pass to MMR algorithm. - filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults - to None. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents and similarity scores selected by maximal - marginal - relevance and score for each. - """ - - # Fetch documents and their scores - docs_scores_embeddings = self.similarity_search_by_vector_returning_embeddings( - embedding, fetch_k, filter=filter - ) - # Assuming documents_with_scores is a list of tuples (Document, score) - - # If you need to split documents and scores for processing (e.g., - # for MMR calculation) - documents, scores, embeddings = ( - zip(*docs_scores_embeddings) if docs_scores_embeddings else ([], [], []) - ) - - # Assume maximal_marginal_relevance method accepts embeddings and - # scores, and returns indices of selected docs - mmr_selected_indices = maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - list(embeddings), - k=k, - lambda_mult=lambda_mult, - ) - - # Filter documents based on MMR-selected indices and map scores - mmr_selected_documents_with_scores = [ - (documents[i], scores[i]) for i in mmr_selected_indices - ] - - return mmr_selected_documents_with_scores - - @_handle_exceptions - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND - diversity - among selected documents. - - Args: - self: An instance of the class - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Optional[Dict[str, Any]] - **kwargs: Any - Returns: - List of Documents selected by maximal marginal relevance. - """ - docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( - embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter - ) - return [doc for doc, _ in docs_and_scores] - - @_handle_exceptions - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. 
- - Maximal marginal relevance optimizes for similarity to query AND - diversity - among selected documents. - - Args: - self: An instance of the class - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Optional[Dict[str, Any]] - **kwargs - Returns: - List of Documents selected by maximal marginal relevance. - - `max_marginal_relevance_search` requires that `query` returns matched - embeddings alongside the match documents. - """ - embedding = self._embed_query(query) - documents = self.max_marginal_relevance_search_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) - return documents - - @_handle_exceptions - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: - """Delete by vector IDs. - Args: - self: An instance of the class - ids: List of ids to delete. - **kwargs - """ - - if ids is None: - raise ValueError("No ids provided to delete.") - - # Compute SHA-256 hashes of the ids and truncate them - hashed_ids = [ - hashlib.sha256(_id.encode()).hexdigest()[:16].upper() for _id in ids - ] - - # Constructing the SQL statement with individual placeholders - placeholders = ", ".join([":id" + str(i + 1) for i in range(len(hashed_ids))]) - - ddl = f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})" - - # Preparing bind variables - bind_vars = { - f"id{i}": hashed_id for i, hashed_id in enumerate(hashed_ids, start=1) - } - - connection = _get_connection(self.client) - if connection is None: - raise ValueError("Failed to acquire a connection.") - with connection.cursor() as cursor: - cursor.execute(ddl, bind_vars) - connection.commit() - - @classmethod - @_handle_exceptions - def from_texts( - cls: Type[OracleVS], - texts: Iterable[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> OracleVS: - client: Any = kwargs.get("client", None) - if client is None: - raise ValueError("client parameter is required...") - - params = kwargs.get("params", {}) - - table_name = str(kwargs.get("table_name", "langchain")) - - distance_strategy = cast( - DistanceStrategy, kwargs.get("distance_strategy", None) - ) - if not isinstance(distance_strategy, DistanceStrategy): - raise TypeError( - f"Expected DistanceStrategy got {type(distance_strategy).__name__} " - ) - - query = kwargs.get("query", "What is a Oracle database") - - drop_table_purge(client, table_name) - - vss = cls( - client=client, - embedding_function=embedding, - table_name=table_name, - distance_strategy=distance_strategy, - query=query, - params=params, - ) - vss.add_texts(texts=list(texts), metadatas=metadatas) - return vss diff --git a/libs/community/langchain_community/vectorstores/pinecone.py b/libs/community/langchain_community/vectorstores/pinecone.py deleted file mode 100644 index 7ad40b2c0..000000000 --- a/libs/community/langchain_community/vectorstores/pinecone.py +++ /dev/null @@ -1,488 +0,0 @@ -from __future__ import annotations - -import logging -import os -import uuid -import warnings -from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Union - -import numpy as np -from langchain_core._api.deprecation import deprecated -from langchain_core.documents import Document 
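A note on the `OracleVS.from_texts` constructor removed just above: it requires an explicit `client` and a concrete `DistanceStrategy` (it raises if either is missing or of the wrong type), and it calls `drop_table_purge` before recreating the table. A minimal usage sketch of the removed wrapper, assuming a reachable Oracle 23ai instance with VECTOR support; the credentials, DSN, table name, and fake embedding model are placeholders, and the deprecation notice above points to `langchain-oracledb` as the supported path going forward:

import oracledb
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_community.vectorstores import OracleVS
from langchain_community.vectorstores.utils import DistanceStrategy

# Placeholder credentials; any existing "langchain_demo" table is dropped first,
# because from_texts calls drop_table_purge before recreating the table.
connection = oracledb.connect(user="scott", password="tiger", dsn="localhost/FREEPDB1")

vs = OracleVS.from_texts(
    ["Oracle 23ai stores embeddings in native VECTOR columns."],
    embedding=DeterministicFakeEmbedding(size=8),
    client=connection,                          # required; from_texts raises without it
    table_name="langchain_demo",
    distance_strategy=DistanceStrategy.COSINE,  # must be a DistanceStrategy member
)
print(vs.similarity_search("vector columns", k=1))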
-from langchain_core.embeddings import Embeddings -from langchain_core.utils.iter import batch_iterate -from langchain_core.vectorstores import VectorStore -from packaging import version - -from langchain_community.vectorstores.utils import ( - DistanceStrategy, - maximal_marginal_relevance, -) - -if TYPE_CHECKING: - from pinecone import Index - -logger = logging.getLogger(__name__) - - -def _import_pinecone() -> Any: - try: - import pinecone - except ImportError as e: - raise ImportError( - "Could not import pinecone python package. " - "Please install it with `pip3 install pinecone`." - ) from e - return pinecone - - -def _is_pinecone_v3() -> bool: - pinecone = _import_pinecone() - pinecone_client_version = pinecone.__version__ - return version.parse(pinecone_client_version) >= version.parse("3.0.0.dev") - - -@deprecated( - since="0.0.18", removal="1.0", alternative_import="langchain_pinecone.Pinecone" -) -class Pinecone(VectorStore): - """`Pinecone` vector store. - - To use, you should have the ``pinecone`` python package installed. - - This version of Pinecone is deprecated. Please use `langchain_pinecone.Pinecone` - instead. - """ - - def __init__( - self, - index: Any, - embedding: Union[Embeddings, Callable], - text_key: str, - namespace: Optional[str] = None, - distance_strategy: Optional[DistanceStrategy] = DistanceStrategy.COSINE, - ): - """Initialize with Pinecone client.""" - pinecone = _import_pinecone() - if not isinstance(embedding, Embeddings): - warnings.warn( - "Passing in `embedding` as a Callable is deprecated. Please pass in an" - " Embeddings object instead." - ) - if not isinstance(index, pinecone.Index): - raise ValueError( - f"client should be an instance of pinecone.Index, got {type(index)}" - ) - self._index = index - self._embedding = embedding - self._text_key = text_key - self._namespace = namespace - self.distance_strategy = distance_strategy - - @property - def embeddings(self) -> Optional[Embeddings]: - """Access the query embedding object if available.""" - if isinstance(self._embedding, Embeddings): - return self._embedding - return None - - def _embed_documents(self, texts: Iterable[str]) -> List[List[float]]: - """Embed search docs.""" - if isinstance(self._embedding, Embeddings): - return self._embedding.embed_documents(list(texts)) - return [self._embedding(t) for t in texts] - - def _embed_query(self, text: str) -> List[float]: - """Embed query text.""" - if isinstance(self._embedding, Embeddings): - return self._embedding.embed_query(text) - return self._embedding(text) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - namespace: Optional[str] = None, - batch_size: int = 32, - embedding_chunk_size: int = 1000, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Upsert optimization is done by chunking the embeddings and upserting them. - This is done to avoid memory issues and optimize using HTTP based embeddings. - For OpenAI embeddings, use pool_threads>4 when constructing the pinecone.Index, - embedding_chunk_size>1000 and batch_size~64 for best performance. - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: Optional list of ids to associate with the texts. - namespace: Optional pinecone namespace to add the texts to. - batch_size: Batch size to use when adding the texts to the vectorstore. 
- embedding_chunk_size: Chunk size to use when embedding the texts. - - Returns: - List of ids from adding the texts into the vectorstore. - - """ - if namespace is None: - namespace = self._namespace - - texts = list(texts) - ids = ids or [str(uuid.uuid4()) for _ in texts] - metadatas = metadatas or [{} for _ in texts] - for metadata, text in zip(metadatas, texts): - metadata[self._text_key] = text - - # For loops to avoid memory issues and optimize when using HTTP based embeddings - # The first loop runs the embeddings, it benefits when using OpenAI embeddings - # The second loops runs the pinecone upsert asynchronously. - for i in range(0, len(texts), embedding_chunk_size): - chunk_texts = texts[i : i + embedding_chunk_size] - chunk_ids = ids[i : i + embedding_chunk_size] - chunk_metadatas = metadatas[i : i + embedding_chunk_size] - embeddings = self._embed_documents(chunk_texts) - async_res = [ - self._index.upsert( - vectors=batch, - namespace=namespace, - async_req=True, - **kwargs, - ) - for batch in batch_iterate( - batch_size, zip(chunk_ids, embeddings, chunk_metadatas) - ) - ] - [res.get() for res in async_res] - - return ids - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - namespace: Optional[str] = None, - ) -> List[Tuple[Document, float]]: - """Return pinecone documents most similar to query, along with scores. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Dictionary of argument(s) to filter on metadata - namespace: Namespace to search in. Default will search in '' namespace. - - Returns: - List of Documents most similar to the query and score for each - """ - return self.similarity_search_by_vector_with_score( - self._embed_query(query), k=k, filter=filter, namespace=namespace - ) - - def similarity_search_by_vector_with_score( - self, - embedding: List[float], - *, - k: int = 4, - filter: Optional[dict] = None, - namespace: Optional[str] = None, - ) -> List[Tuple[Document, float]]: - """Return pinecone documents most similar to embedding, along with scores.""" - - if namespace is None: - namespace = self._namespace - docs = [] - results = self._index.query( - vector=[embedding], - top_k=k, - include_metadata=True, - namespace=namespace, - filter=filter, - ) - for res in results["matches"]: - metadata = res["metadata"] - if self._text_key in metadata: - text = metadata.pop(self._text_key) - score = res["score"] - docs.append((Document(page_content=text, metadata=metadata), score)) - else: - logger.warning( - f"Found document with no `{self._text_key}` key. Skipping." - ) - return docs - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[dict] = None, - namespace: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return pinecone documents most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Dictionary of argument(s) to filter on metadata - namespace: Namespace to search in. Default will search in '' namespace. 
- - Returns: - List of Documents most similar to the query and score for each - """ - docs_and_scores = self.similarity_search_with_score( - query, k=k, filter=filter, namespace=namespace, **kwargs - ) - return [doc for doc, _ in docs_and_scores] - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - """ - - if self.distance_strategy == DistanceStrategy.COSINE: - return self._cosine_relevance_score_fn - elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - return self._max_inner_product_relevance_score_fn - elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: - return self._euclidean_relevance_score_fn - else: - raise ValueError( - "Unknown distance strategy, must be cosine, max_inner_product " - "(dot product), or euclidean" - ) - - @staticmethod - def _cosine_relevance_score_fn(score: float) -> float: - """Pinecone returns cosine similarity scores between [-1,1]""" - return (score + 1) / 2 - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - namespace: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. - """ - if namespace is None: - namespace = self._namespace - results = self._index.query( - vector=[embedding], - top_k=fetch_k, - include_values=True, - include_metadata=True, - namespace=namespace, - filter=filter, - ) - mmr_selected = maximal_marginal_relevance( - np.array([embedding], dtype=np.float32), - [item["values"] for item in results["matches"]], - k=k, - lambda_mult=lambda_mult, - ) - selected = [results["matches"][i]["metadata"] for i in mmr_selected] - return [ - Document(page_content=metadata.pop((self._text_key)), metadata=metadata) - for metadata in selected - ] - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[dict] = None, - namespace: Optional[str] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance. 
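The `_cosine_relevance_score_fn` staticmethod deleted just above maps Pinecone's cosine similarity range [-1, 1] onto the [0, 1] relevance scale LangChain expects. The same arithmetic, restated as a tiny self-contained check:

def cosine_relevance_score(score: float) -> float:
    # Pinecone returns cosine similarity in [-1, 1]; LangChain relevance scores live in [0, 1].
    return (score + 1) / 2

assert cosine_relevance_score(-1.0) == 0.0   # opposite vectors -> least relevant
assert cosine_relevance_score(0.0) == 0.5    # orthogonal vectors -> midpoint
assert cosine_relevance_score(1.0) == 1.0    # identical direction -> most relevant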
- """ - embedding = self._embed_query(query) - return self.max_marginal_relevance_search_by_vector( - embedding, k, fetch_k, lambda_mult, filter, namespace - ) - - @classmethod - def get_pinecone_index( - cls, - index_name: Optional[str], - pool_threads: int = 4, - ) -> Index: - """Return a Pinecone Index instance. - - Args: - index_name: Name of the index to use. - pool_threads: Number of threads to use for index upsert. - Returns: - Pinecone Index instance.""" - - pinecone = _import_pinecone() - - if _is_pinecone_v3(): - pinecone_instance = pinecone.Pinecone( - api_key=os.environ.get("PINECONE_API_KEY"), pool_threads=pool_threads - ) - indexes = pinecone_instance.list_indexes() - index_names = [i.name for i in indexes.index_list["indexes"]] - else: - index_names = pinecone.list_indexes() - - if index_name in index_names: - index = ( - pinecone_instance.Index(index_name) - if _is_pinecone_v3() - else pinecone.Index(index_name, pool_threads=pool_threads) - ) - elif len(index_names) == 0: - raise ValueError( - "No active indexes found in your Pinecone project, " - "are you sure you're using the right Pinecone API key and Environment? " - "Please double check your Pinecone dashboard." - ) - else: - raise ValueError( - f"Index '{index_name}' not found in your Pinecone project. " - f"Did you mean one of the following indexes: {', '.join(index_names)}" - ) - return index - - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - batch_size: int = 32, - text_key: str = "text", - namespace: Optional[str] = None, - index_name: Optional[str] = None, - upsert_kwargs: Optional[dict] = None, - pool_threads: int = 4, - embeddings_chunk_size: int = 1000, - **kwargs: Any, - ) -> Pinecone: - """ - DEPRECATED: use langchain_pinecone.PineconeVectorStore.from_texts instead: - Construct Pinecone wrapper from raw documents. - - This is a user friendly interface that: - 1. Embeds documents. - 2. Adds the documents to a provided Pinecone index - - This is intended to be a quick way to get started. - - The `pool_threads` affects the speed of the upsert operations. - - Example: - .. code-block:: python - - from langchain_pinecone import PineconeVectorStore - from langchain_openai import OpenAIEmbeddings - - embeddings = OpenAIEmbeddings() - index_name = "my-index" - namespace = "my-namespace" - vectorstore = Pinecone( - index_name=index_name, - embedding=embedding, - namespace=namespace, - ) - """ - pinecone_index = cls.get_pinecone_index(index_name, pool_threads) - pinecone = cls(pinecone_index, embedding, text_key, namespace, **kwargs) - - pinecone.add_texts( - texts, - metadatas=metadatas, - ids=ids, - namespace=namespace, - batch_size=batch_size, - embedding_chunk_size=embeddings_chunk_size, - **(upsert_kwargs or {}), - ) - return pinecone - - @classmethod - def from_existing_index( - cls, - index_name: str, - embedding: Embeddings, - text_key: str = "text", - namespace: Optional[str] = None, - pool_threads: int = 4, - ) -> Pinecone: - """Load pinecone vectorstore from index name.""" - pinecone_index = cls.get_pinecone_index(index_name, pool_threads) - return cls(pinecone_index, embedding, text_key, namespace) - - def delete( - self, - ids: Optional[List[str]] = None, - delete_all: Optional[bool] = None, - namespace: Optional[str] = None, - filter: Optional[dict] = None, - **kwargs: Any, - ) -> None: - """Delete by vector IDs or filter. - Args: - ids: List of ids to delete. 
- filter: Dictionary of conditions to filter vectors to delete. - """ - - if namespace is None: - namespace = self._namespace - - if delete_all: - self._index.delete(delete_all=True, namespace=namespace, **kwargs) - elif ids is not None: - chunk_size = 1000 - for i in range(0, len(ids), chunk_size): - chunk = ids[i : i + chunk_size] - self._index.delete(ids=chunk, namespace=namespace, **kwargs) - elif filter is not None: - self._index.delete(filter=filter, namespace=namespace, **kwargs) - else: - raise ValueError("Either ids, delete_all, or filter must be provided.") - - return None diff --git a/libs/community/langchain_community/vectorstores/qdrant.py b/libs/community/langchain_community/vectorstores/qdrant.py deleted file mode 100644 index b4a9c60b2..000000000 --- a/libs/community/langchain_community/vectorstores/qdrant.py +++ /dev/null @@ -1,2279 +0,0 @@ -from __future__ import annotations - -import functools -import uuid -import warnings -from itertools import islice -from operator import itemgetter -from typing import ( - TYPE_CHECKING, - Any, - AsyncGenerator, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -import numpy as np -from langchain_core._api.deprecation import deprecated -from langchain_core.embeddings import Embeddings -from langchain_core.runnables.config import run_in_executor -from langchain_core.vectorstores import VectorStore - -from langchain_community.docstore.document import Document -from langchain_community.vectorstores.utils import maximal_marginal_relevance - -if TYPE_CHECKING: - from qdrant_client import grpc # noqa - from qdrant_client.conversions import common_types - from qdrant_client.http import models as rest - - DictFilter = Dict[str, Union[str, int, bool, dict, list]] - MetadataFilter = Union[DictFilter, common_types.Filter] - - -class QdrantException(Exception): - """`Qdrant` related exceptions.""" - - -def sync_call_fallback(method: Callable) -> Callable: - """ - Decorator to call the synchronous method of the class if the async method is not - implemented. This decorator might be only used for the methods that are defined - as async in the class. - """ - - @functools.wraps(method) - async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - try: - return await method(self, *args, **kwargs) - except NotImplementedError: - # If the async method is not implemented, call the synchronous method - # by removing the first letter from the method name. For example, - # if the async method is called ``aaad_texts``, the synchronous method - # will be called ``aad_texts``. - return await run_in_executor( - None, getattr(self, method.__name__[1:]), *args, **kwargs - ) - - return wrapper - - -@deprecated(since="0.0.37", removal="1.0", alternative_import="langchain_qdrant.Qdrant") -class Qdrant(VectorStore): - """`Qdrant` vector store. - - To use you should have the ``qdrant-client`` package installed. - - Example: - .. 
code-block:: python - - from qdrant_client import QdrantClient - from langchain_community.vectorstores import Qdrant - - client = QdrantClient() - collection_name = "MyCollection" - qdrant = Qdrant(client, collection_name, embedding_function) - """ - - CONTENT_KEY: str = "page_content" - METADATA_KEY: str = "metadata" - VECTOR_NAME = None - - def __init__( - self, - client: Any, - collection_name: str, - embeddings: Optional[Embeddings] = None, - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - distance_strategy: str = "COSINE", - vector_name: Optional[str] = VECTOR_NAME, - async_client: Optional[Any] = None, - embedding_function: Optional[Callable] = None, # deprecated - ): - """Initialize with necessary components.""" - try: - import qdrant_client - except ImportError: - raise ImportError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - - if not isinstance(client, qdrant_client.QdrantClient): - raise ValueError( - f"client should be an instance of qdrant_client.QdrantClient, " - f"got {type(client)}" - ) - - if async_client is not None and not isinstance( - async_client, qdrant_client.AsyncQdrantClient - ): - raise ValueError( - f"async_client should be an instance of qdrant_client.AsyncQdrantClient" - f"got {type(async_client)}" - ) - - if embeddings is None and embedding_function is None: - raise ValueError( - "`embeddings` value can't be None. Pass `Embeddings` instance." - ) - - if embeddings is not None and embedding_function is not None: - raise ValueError( - "Both `embeddings` and `embedding_function` are passed. " - "Use `embeddings` only." - ) - - self._embeddings = embeddings - self._embeddings_function = embedding_function - self.client: qdrant_client.QdrantClient = client - self.async_client: Optional[qdrant_client.AsyncQdrantClient] = async_client - self.collection_name = collection_name - self.content_payload_key = content_payload_key or self.CONTENT_KEY - self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY - self.vector_name = vector_name or self.VECTOR_NAME - - if embedding_function is not None: - warnings.warn( - "Using `embedding_function` is deprecated. " - "Pass `Embeddings` instance to `embeddings` instead." - ) - - if not isinstance(embeddings, Embeddings): - warnings.warn( - "`embeddings` should be an instance of `Embeddings`." - "Using `embeddings` as `embedding_function` which is deprecated" - ) - self._embeddings_function = embeddings - self._embeddings = None - - self.distance_strategy = distance_strategy.upper() - - @property - def embeddings(self) -> Optional[Embeddings]: - return self._embeddings - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - batch_size: - How many vectors upload per-request. - Default: 64 - - Returns: - List of ids from adding the texts into the vectorstore. 
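The constructor removed above validates the client type and wires up the embeddings object plus the content and metadata payload keys. A self-contained sketch of the class as it existed before this removal, using qdrant-client's in-memory mode and a fake embedding model; the collection name and vector size are illustrative placeholders:

from langchain_community.vectorstores import Qdrant
from langchain_core.embeddings import DeterministicFakeEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

embeddings = DeterministicFakeEmbedding(size=4)
client = QdrantClient(":memory:")  # in-memory instance, no server required
client.create_collection(
    collection_name="MyCollection",
    vectors_config=rest.VectorParams(size=4, distance=rest.Distance.COSINE),
)

qdrant = Qdrant(client, "MyCollection", embeddings=embeddings)
qdrant.add_texts(["Qdrant keeps the page content and metadata in the point payload."])
print(qdrant.similarity_search("payload", k=1))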
- """ - added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, metadatas, ids, batch_size - ): - self.client.upsert( - collection_name=self.collection_name, points=points, **kwargs - ) - added_ids.extend(batch_ids) - - return added_ids - - @sync_call_fallback - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - batch_size: - How many vectors upload per-request. - Default: 64 - - Returns: - List of ids from adding the texts into the vectorstore. - """ - from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal - - if self.async_client is None or isinstance( - self.async_client._client, AsyncQdrantLocal - ): - raise NotImplementedError( - "QdrantLocal cannot interoperate with sync and async clients" - ) - - added_ids = [] - async for batch_ids, points in self._agenerate_rest_batches( - texts, metadatas, ids, batch_size - ): - await self.async_client.upsert( - collection_name=self.collection_name, points=points, **kwargs - ) - added_ids.extend(batch_ids) - - return added_ids - - def similarity_search( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to QdrantClient.search() - - Returns: - List of Documents most similar to the query. 
- """ - results = self.similarity_search_with_score( - query, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def asimilarity_search( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to query. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - Returns: - List of Documents most similar to the query. - """ - results = await self.asimilarity_search_with_score(query, k, filter, **kwargs) - return list(map(itemgetter(0), results)) - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to QdrantClient.search() - - Returns: - List of documents most similar to the query text and distance for each. - """ - return self.similarity_search_with_score_by_vector( - self._embed_query(query), - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - @sync_call_fallback - async def asimilarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. 
- score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to - AsyncQdrantClient.Search(). - - Returns: - List of documents most similar to the query text and distance for each. - """ - query_embedding = await self._aembed_query(query) - return await self.asimilarity_search_with_score_by_vector( - query_embedding, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to QdrantClient.search() - - Returns: - List of Documents most similar to the query. 
- """ - results = self.similarity_search_with_score_by_vector( - embedding, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def asimilarity_search_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to - AsyncQdrantClient.Search(). - - Returns: - List of Documents most similar to the query. - """ - results = await self.asimilarity_search_with_score_by_vector( - embedding, - k, - filter=filter, - search_params=search_params, - offset=offset, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. 
- Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to QdrantClient.search() - - Returns: - List of documents most similar to the query text and distance for each. - """ - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, embedding) # type: ignore[assignment] - - results = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - query_filter=qdrant_filter, - search_params=search_params, - limit=k, - offset=offset, - with_payload=True, - with_vectors=False, # Langchain does not expect vectors to be returned - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return [ - ( - self._document_from_scored_point( - result, - self.collection_name, - self.content_payload_key, - self.metadata_payload_key, - ), - result.score, - ) - for result in results - ] - - @sync_call_fallback - async def asimilarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs most similar to embedding vector. - - Args: - embedding: Embedding vector to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - offset: - Offset of the first result to return. - May be used to paginate results. - Note: large offset values may cause performance issues. - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to - AsyncQdrantClient.Search(). - - Returns: - List of documents most similar to the query text and distance for each. 
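The method body deleted above emits a DeprecationWarning when `filter` is a plain dict and converts it via `_qdrant_filter_from_dict`; the preferred form is a qdrant-client filter object. A sketch of the native form; the "metadata.source" key assumes the default metadata payload key plus an illustrative "source" field, neither of which is fixed by the code above:

from qdrant_client.http import models as rest

# Roughly equivalent to the deprecated dict form {"source": "docs"}:
native_filter = rest.Filter(
    must=[
        rest.FieldCondition(
            key="metadata.source",               # "<metadata payload key>.<field>"
            match=rest.MatchValue(value="docs"),
        )
    ]
)
# hits = qdrant.similarity_search_with_score("query text", k=4, filter=native_filter)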
- """ - from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal - - if self.async_client is None or isinstance( - self.async_client._client, AsyncQdrantLocal - ): - raise NotImplementedError( - "QdrantLocal cannot interoperate with sync and async clients" - ) - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, embedding) # type: ignore[assignment] - - results = await self.async_client.search( - collection_name=self.collection_name, - query_vector=query_vector, - query_filter=qdrant_filter, - search_params=search_params, - limit=k, - offset=offset, - with_payload=True, - with_vectors=False, # Langchain does not expect vectors to be returned - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return [ - ( - self._document_from_scored_point( - result, - self.collection_name, - self.content_payload_key, - self.metadata_payload_key, - ), - result.score, - ) - for result in results - ] - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to QdrantClient.search() - Returns: - List of Documents selected by maximal marginal relevance. 
- """ - query_embedding = self._embed_query(query) - return self.max_marginal_relevance_search_by_vector( - query_embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - search_params=search_params, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - @sync_call_fallback - async def amax_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to - AsyncQdrantClient.Search(). - Returns: - List of Documents selected by maximal marginal relevance. - """ - query_embedding = await self._aembed_query(query) - return await self.amax_marginal_relevance_search_by_vector( - query_embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - search_params=search_params, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. 
- filter: Filter by metadata. Defaults to None. - search_params: Additional search params - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to QdrantClient.search() - Returns: - List of Documents selected by maximal marginal relevance. - """ - results = self.max_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - search_params=search_params, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - @sync_call_fallback - async def amax_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to - AsyncQdrantClient.Search(). - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. 
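The `*_by_vector` variants documented here accept a precomputed query embedding instead of raw text; a short self-contained sketch under the same illustrative assumptions:

```python
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Qdrant

embeddings = FakeEmbeddings(size=32)
qdrant = Qdrant.from_texts(
    ["doc one", "doc two", "doc three"],
    embeddings,
    location=":memory:",
    collection_name="mmr-by-vector-demo",
)

# Embed the query yourself, then hand the raw vector to the *_by_vector method.
query_vector = embeddings.embed_query("doc")
docs = qdrant.max_marginal_relevance_search_by_vector(query_vector, k=2, fetch_k=3)
```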
- """ - results = await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - search_params=search_params, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - return list(map(itemgetter(0), results)) - - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter: Filter by metadata. Defaults to None. - search_params: Additional search params - score_threshold: - Define a minimal score threshold for the result. - If defined, less similar results will not be returned. - Score of the returned result might be higher or smaller than the - threshold depending on the Distance function used. - E.g. for cosine similarity only higher scores will be returned. - consistency: - Read consistency of the search. Defines how many replicas should be - queried before returning the result. - Values: - - int - number of replicas to query, values should present in all - queried replicas - - 'majority' - query all replicas, but return values present in the - majority of replicas - - 'quorum' - query the majority of replicas, return values present in - all of them - - 'all' - query all replicas, and return values present in all replicas - **kwargs: - Any other named arguments to pass through to QdrantClient.search() - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. 
- """ - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, query_vector) # type: ignore[assignment] - - results = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - query_filter=filter, - search_params=search_params, - limit=fetch_k, - with_payload=True, - with_vectors=True, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - embeddings = [ - result.vector.get(self.vector_name) - if self.vector_name is not None - else result.vector - for result in results - ] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - return [ - ( - self._document_from_scored_point( - results[i], - self.collection_name, - self.content_payload_key, - self.metadata_payload_key, - ), - results[i].score, - ) - for i in mmr_selected - ] - - @sync_call_fallback - async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - Defaults to 20. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - Returns: - List of Documents selected by maximal marginal relevance and distance for - each. - """ - from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal - - if self.async_client is None or isinstance( - self.async_client._client, AsyncQdrantLocal - ): - raise NotImplementedError( - "QdrantLocal cannot interoperate with sync and async clients" - ) - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, query_vector) # type: ignore[assignment] - - results = await self.async_client.search( - collection_name=self.collection_name, - query_vector=query_vector, - query_filter=filter, - search_params=search_params, - limit=fetch_k, - with_payload=True, - with_vectors=True, - score_threshold=score_threshold, - consistency=consistency, - **kwargs, - ) - embeddings = [ - result.vector.get(self.vector_name) - if self.vector_name is not None - else result.vector - for result in results - ] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - return [ - ( - self._document_from_scored_point( - results[i], - self.collection_name, - self.content_payload_key, - self.metadata_payload_key, - ), - results[i].score, - ) - for i in mmr_selected - ] - - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - True if deletion is successful, False otherwise. 
- """ - from qdrant_client.http import models as rest - - result = self.client.delete( - collection_name=self.collection_name, - points_selector=ids, - ) - return result.status == rest.UpdateStatus.COMPLETED - - @sync_call_fallback - async def adelete( - self, ids: Optional[List[str]] = None, **kwargs: Any - ) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - True if deletion is successful, False otherwise. - """ - from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal - - if self.async_client is None or isinstance( - self.async_client._client, AsyncQdrantLocal - ): - raise NotImplementedError( - "QdrantLocal cannot interoperate with sync and async clients" - ) - - from qdrant_client.http import models as rest - - result = await self.async_client.delete( - collection_name=self.collection_name, - points_selector=ids, - ) - - return result.status == rest.UpdateStatus.COMPLETED - - @classmethod - def from_texts( - cls: Type[Qdrant], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - on_disk: Optional[bool] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - """Construct Qdrant wrapper from a list of texts. - - Args: - texts: A list of texts to be indexed in Qdrant. - embedding: A subclass of `Embeddings`, responsible for text vectorization. - metadatas: - An optional list of metadata. If provided it has to be of the same - length as a list of texts. - ids: - Optional list of ids to associate with the texts. Ids have to be - uuid-like strings. - location: - If `:memory:` - use in-memory Qdrant instance. - If `str` - use it as a `url` parameter. - If `None` - fallback to relying on `host` and `port` parameters. - url: either host or str of "Optional[scheme], host, Optional[port], - Optional[prefix]". Default: `None` - port: Port of the REST API interface. Default: 6333 - grpc_port: Port of the gRPC interface. Default: 6334 - prefer_grpc: - If true - use gPRC interface whenever possible in custom methods. - Default: False - https: If true - use HTTPS(SSL) protocol. Default: None - api_key: API key for authentication in Qdrant Cloud. Default: None - prefix: - If not None - add prefix to the REST URL path. - Example: service/v1 will result in - http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. 
- Default: None - timeout: - Timeout for REST and gRPC API requests. - Default: 5.0 seconds for REST and unlimited for gRPC - host: - Host name of Qdrant service. If url and host are None, set to - 'localhost'. Default: None - path: - Path in which the vectors will be stored while using local mode. - Default: None - collection_name: - Name of the Qdrant collection to be used. If not provided, - it will be created randomly. Default: None - distance_func: - Distance function. One of: "Cosine" / "Euclid" / "Dot". - Default: "Cosine" - content_payload_key: - A payload key used to store the content of the document. - Default: "page_content" - metadata_payload_key: - A payload key used to store the metadata of the document. - Default: "metadata" - vector_name: - Name of the vector to be used internally in Qdrant. - Default: None - batch_size: - How many vectors upload per-request. - Default: 64 - shard_number: Number of shards in collection. Default is 1, minimum is 1. - replication_factor: - Replication factor for collection. Default is 1, minimum is 1. - Defines how many copies of each shard will be created. - Have effect only in distributed mode. - write_consistency_factor: - Write consistency factor for collection. Default is 1, minimum is 1. - Defines how many replicas should apply the operation for us to consider - it successful. Increasing this number will make the collection more - resilient to inconsistencies, but will also make it fail if not enough - replicas are available. - Does not have any performance impact. - Have effect only in distributed mode. - on_disk_payload: - If true - point`s payload will not be stored in memory. - It will be read from the disk every time it is requested. - This setting saves RAM by (slightly) increasing the response time. - Note: those payload values that are involved in filtering and are - indexed - remain in RAM. - hnsw_config: Params for HNSW index - optimizers_config: Params for optimizer - wal_config: Params for Write-Ahead-Log - quantization_config: - Params for quantization, if None - quantization will be disabled - init_from: - Use data stored in another collection to initialize this collection - force_recreate: - Force recreating the collection - **kwargs: - Additional arguments passed directly into REST client initialization - - This is a user-friendly interface that: - 1. Creates embeddings, one for each text - 2. Initializes the Qdrant database as an in-memory docstore by default - (and overridable to a remote docstore) - 3. Adds the text embeddings to the Qdrant database - - This is intended to be a quick way to get started. - - Example: - .. 
code-block:: python - - from langchain_community.vectorstores import Qdrant - from langchain_community.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - qdrant = Qdrant.from_texts(texts, embeddings, "localhost") - """ - qdrant = cls.construct_instance( - texts, - embedding, - location, - url, - port, - grpc_port, - prefer_grpc, - https, - api_key, - prefix, - timeout, - host, - path, - collection_name, - distance_func, - content_payload_key, - metadata_payload_key, - vector_name, - shard_number, - replication_factor, - write_consistency_factor, - on_disk_payload, - hnsw_config, - optimizers_config, - wal_config, - quantization_config, - init_from, - on_disk, - force_recreate, - **kwargs, - ) - qdrant.add_texts(texts, metadatas, ids, batch_size) - return qdrant - - @classmethod - def from_existing_collection( - cls: Type[Qdrant], - embedding: Embeddings, - path: str, - collection_name: str, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - **kwargs: Any, - ) -> Qdrant: - """ - Get instance of an existing Qdrant collection. - This method will return the instance of the store without inserting any new - embeddings - """ - client, async_client = cls._generate_clients( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - **kwargs, - ) - return cls( - client=client, - async_client=async_client, - collection_name=collection_name, - embeddings=embedding, - **kwargs, - ) - - @classmethod - @sync_call_fallback - async def afrom_texts( - cls: Type[Qdrant], - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - on_disk: Optional[bool] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - """Construct Qdrant wrapper from a list of texts. - - Args: - texts: A list of texts to be indexed in Qdrant. - embedding: A subclass of `Embeddings`, responsible for text vectorization. - metadatas: - An optional list of metadata. If provided it has to be of the same - length as a list of texts. - ids: - Optional list of ids to associate with the texts. 
Ids have to be - uuid-like strings. - location: - If `:memory:` - use in-memory Qdrant instance. - If `str` - use it as a `url` parameter. - If `None` - fallback to relying on `host` and `port` parameters. - url: either host or str of "Optional[scheme], host, Optional[port], - Optional[prefix]". Default: `None` - port: Port of the REST API interface. Default: 6333 - grpc_port: Port of the gRPC interface. Default: 6334 - prefer_grpc: - If true - use gPRC interface whenever possible in custom methods. - Default: False - https: If true - use HTTPS(SSL) protocol. Default: None - api_key: API key for authentication in Qdrant Cloud. Default: None - prefix: - If not None - add prefix to the REST URL path. - Example: service/v1 will result in - http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. - Default: None - timeout: - Timeout for REST and gRPC API requests. - Default: 5.0 seconds for REST and unlimited for gRPC - host: - Host name of Qdrant service. If url and host are None, set to - 'localhost'. Default: None - path: - Path in which the vectors will be stored while using local mode. - Default: None - collection_name: - Name of the Qdrant collection to be used. If not provided, - it will be created randomly. Default: None - distance_func: - Distance function. One of: "Cosine" / "Euclid" / "Dot". - Default: "Cosine" - content_payload_key: - A payload key used to store the content of the document. - Default: "page_content" - metadata_payload_key: - A payload key used to store the metadata of the document. - Default: "metadata" - vector_name: - Name of the vector to be used internally in Qdrant. - Default: None - batch_size: - How many vectors upload per-request. - Default: 64 - shard_number: Number of shards in collection. Default is 1, minimum is 1. - replication_factor: - Replication factor for collection. Default is 1, minimum is 1. - Defines how many copies of each shard will be created. - Have effect only in distributed mode. - write_consistency_factor: - Write consistency factor for collection. Default is 1, minimum is 1. - Defines how many replicas should apply the operation for us to consider - it successful. Increasing this number will make the collection more - resilient to inconsistencies, but will also make it fail if not enough - replicas are available. - Does not have any performance impact. - Have effect only in distributed mode. - on_disk_payload: - If true - point`s payload will not be stored in memory. - It will be read from the disk every time it is requested. - This setting saves RAM by (slightly) increasing the response time. - Note: those payload values that are involved in filtering and are - indexed - remain in RAM. - hnsw_config: Params for HNSW index - optimizers_config: Params for optimizer - wal_config: Params for Write-Ahead-Log - quantization_config: - Params for quantization, if None - quantization will be disabled - init_from: - Use data stored in another collection to initialize this collection - force_recreate: - Force recreating the collection - **kwargs: - Additional arguments passed directly into REST client initialization - - This is a user-friendly interface that: - 1. Creates embeddings, one for each text - 2. Initializes the Qdrant database as an in-memory docstore by default - (and overridable to a remote docstore) - 3. Adds the text embeddings to the Qdrant database - - This is intended to be a quick way to get started. - - Example: - .. 
code-block:: python - - from langchain_community.vectorstores import Qdrant - from langchain_community.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost") - """ - qdrant = await cls.aconstruct_instance( - texts, - embedding, - location, - url, - port, - grpc_port, - prefer_grpc, - https, - api_key, - prefix, - timeout, - host, - path, - collection_name, - distance_func, - content_payload_key, - metadata_payload_key, - vector_name, - shard_number, - replication_factor, - write_consistency_factor, - on_disk_payload, - hnsw_config, - optimizers_config, - wal_config, - quantization_config, - init_from, - on_disk, - force_recreate, - **kwargs, - ) - await qdrant.aadd_texts(texts, metadatas, ids, batch_size) - return qdrant - - @classmethod - def construct_instance( - cls: Type[Qdrant], - texts: List[str], - embedding: Embeddings, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - on_disk: Optional[bool] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - try: - import qdrant_client # noqa - except ImportError: - raise ImportError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - from grpc import RpcError - from qdrant_client.http import models as rest - from qdrant_client.http.exceptions import UnexpectedResponse - - # Just do a single quick embedding to get vector size - partial_embeddings = embedding.embed_documents(texts[:1]) - vector_size = len(partial_embeddings[0]) - collection_name = collection_name or uuid.uuid4().hex - distance_func = distance_func.upper() - client, async_client = cls._generate_clients( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - **kwargs, - ) - try: - # Skip any validation in case of forced collection recreate. - if force_recreate: - raise ValueError - - # Get the vector configuration of the existing collection and vector, if it - # was specified. If the old configuration does not match the current one, - # an exception is being thrown. 
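The construction path shown above follows a deliberate pattern: `force_recreate` raises `ValueError` so that the same `except` branch that handles a missing collection also handles forced recreation. A condensed qdrant-client sketch of that control flow; the client location, collection name, and sizes are assumptions:

```python
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse

client = QdrantClient(":memory:")
collection_name = "construct-demo"
vector_size = 32
force_recreate = False

try:
    if force_recreate:
        raise ValueError  # jump straight to the creation branch below
    info = client.get_collection(collection_name=collection_name)
    if info.config.params.vectors.size != vector_size:
        raise ValueError("dimension mismatch; recreate the collection")
except (UnexpectedResponse, ValueError):
    # Missing collection, mismatched config, or force_recreate all end up here.
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(size=vector_size, distance=rest.Distance.COSINE),
    )
```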
- collection_info = client.get_collection(collection_name=collection_name) - current_vector_config = collection_info.config.params.vectors - if isinstance(current_vector_config, dict) and vector_name is not None: - if vector_name not in current_vector_config: - raise QdrantException( - f"Existing Qdrant collection {collection_name} does not " - f"contain vector named {vector_name}. Did you mean one of the " - f"existing vectors: {', '.join(current_vector_config.keys())}? " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - current_vector_config = current_vector_config.get(vector_name) - elif isinstance(current_vector_config, dict) and vector_name is None: - raise QdrantException( - f"Existing Qdrant collection {collection_name} uses named vectors. " - f"If you want to reuse it, please set `vector_name` to any of the " - f"existing named vectors: " - f"{', '.join(current_vector_config.keys())}." - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - elif ( - not isinstance(current_vector_config, dict) and vector_name is not None - ): - raise QdrantException( - f"Existing Qdrant collection {collection_name} doesn't use named " - f"vectors. If you want to reuse it, please set `vector_name` to " - f"`None`. If you want to recreate the collection, set " - f"`force_recreate` parameter to `True`." - ) - - # Check if the vector configuration has the same dimensionality. - if current_vector_config.size != vector_size: - raise QdrantException( - f"Existing Qdrant collection is configured for vectors with " - f"{current_vector_config.size} " - f"dimensions. Selected embeddings are {vector_size}-dimensional. " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - - current_distance_func = current_vector_config.distance.name.upper() - if current_distance_func != distance_func: - raise QdrantException( - f"Existing Qdrant collection is configured for " - f"{current_distance_func} similarity, but requested " - f"{distance_func}. Please set `distance_func` parameter to " - f"`{current_distance_func}` if you want to reuse it. " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - except (UnexpectedResponse, RpcError, ValueError): - vectors_config = rest.VectorParams( - size=vector_size, - distance=rest.Distance[distance_func], - on_disk=on_disk, - ) - - # If vector name was provided, we're going to use the named vectors feature - # with just a single vector. 
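The comment above introduces the named-vectors branch that follows: when `vector_name` is set, `vectors_config` is wrapped in a dict keyed by that name. A brief qdrant-client sketch of the two shapes (collection and vector names are assumptions):

```python
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient(":memory:")
params = rest.VectorParams(size=32, distance=rest.Distance.COSINE)

# Unnamed: a single anonymous vector per point.
client.recreate_collection(collection_name="plain", vectors_config=params)

# Named: the same params keyed by the vector name, matching the
# `{vector_name: vectors_config}` branch in the surrounding code.
client.recreate_collection(collection_name="named", vectors_config={"content": params})
```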
- if vector_name is not None: - vectors_config = { - vector_name: vectors_config, - } - - client.recreate_collection( - collection_name=collection_name, - vectors_config=vectors_config, - shard_number=shard_number, - replication_factor=replication_factor, - write_consistency_factor=write_consistency_factor, - on_disk_payload=on_disk_payload, - hnsw_config=hnsw_config, - optimizers_config=optimizers_config, - wal_config=wal_config, - quantization_config=quantization_config, - init_from=init_from, - timeout=timeout, - ) - qdrant = cls( - client=client, - collection_name=collection_name, - embeddings=embedding, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - distance_strategy=distance_func, - vector_name=vector_name, - async_client=async_client, - ) - return qdrant - - @classmethod - async def aconstruct_instance( - cls: Type[Qdrant], - texts: List[str], - embedding: Embeddings, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, - distance_func: str = "Cosine", - content_payload_key: str = CONTENT_KEY, - metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, - hnsw_config: Optional[common_types.HnswConfigDiff] = None, - optimizers_config: Optional[common_types.OptimizersConfigDiff] = None, - wal_config: Optional[common_types.WalConfigDiff] = None, - quantization_config: Optional[common_types.QuantizationConfig] = None, - init_from: Optional[common_types.InitFrom] = None, - on_disk: Optional[bool] = None, - force_recreate: bool = False, - **kwargs: Any, - ) -> Qdrant: - try: - import qdrant_client # noqa - except ImportError: - raise ImportError( - "Could not import qdrant-client python package. " - "Please install it with `pip install qdrant-client`." - ) - from grpc import RpcError - from qdrant_client.http import models as rest - from qdrant_client.http.exceptions import UnexpectedResponse - - # Just do a single quick embedding to get vector size - partial_embeddings = await embedding.aembed_documents(texts[:1]) - vector_size = len(partial_embeddings[0]) - collection_name = collection_name or uuid.uuid4().hex - distance_func = distance_func.upper() - client, async_client = cls._generate_clients( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - **kwargs, - ) - try: - # Skip any validation in case of forced collection recreate. - if force_recreate: - raise ValueError - - # Get the vector configuration of the existing collection and vector, if it - # was specified. If the old configuration does not match the current one, - # an exception is being thrown. 
- collection_info = client.get_collection(collection_name=collection_name) - current_vector_config = collection_info.config.params.vectors - if isinstance(current_vector_config, dict) and vector_name is not None: - if vector_name not in current_vector_config: - raise QdrantException( - f"Existing Qdrant collection {collection_name} does not " - f"contain vector named {vector_name}. Did you mean one of the " - f"existing vectors: {', '.join(current_vector_config.keys())}? " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - current_vector_config = current_vector_config.get(vector_name) - elif isinstance(current_vector_config, dict) and vector_name is None: - raise QdrantException( - f"Existing Qdrant collection {collection_name} uses named vectors. " - f"If you want to reuse it, please set `vector_name` to any of the " - f"existing named vectors: " - f"{', '.join(current_vector_config.keys())}." - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - elif ( - not isinstance(current_vector_config, dict) and vector_name is not None - ): - raise QdrantException( - f"Existing Qdrant collection {collection_name} doesn't use named " - f"vectors. If you want to reuse it, please set `vector_name` to " - f"`None`. If you want to recreate the collection, set " - f"`force_recreate` parameter to `True`." - ) - - # Check if the vector configuration has the same dimensionality. - if current_vector_config.size != vector_size: - raise QdrantException( - f"Existing Qdrant collection is configured for vectors with " - f"{current_vector_config.size} " - f"dimensions. Selected embeddings are {vector_size}-dimensional. " - f"If you want to recreate the collection, set `force_recreate` " - f"parameter to `True`." - ) - - current_distance_func = current_vector_config.distance.name.upper() - if current_distance_func != distance_func: - raise QdrantException( - f"Existing Qdrant collection is configured for " - f"{current_vector_config.distance} " - f"similarity. Please set `distance_func` parameter to " - f"`{distance_func}` if you want to reuse it. If you want to " - f"recreate the collection, set `force_recreate` parameter to " - f"`True`." - ) - except (UnexpectedResponse, RpcError, ValueError): - vectors_config = rest.VectorParams( - size=vector_size, - distance=rest.Distance[distance_func], - on_disk=on_disk, - ) - - # If vector name was provided, we're going to use the named vectors feature - # with just a single vector. 
- if vector_name is not None: - vectors_config = { - vector_name: vectors_config, - } - - client.recreate_collection( - collection_name=collection_name, - vectors_config=vectors_config, - shard_number=shard_number, - replication_factor=replication_factor, - write_consistency_factor=write_consistency_factor, - on_disk_payload=on_disk_payload, - hnsw_config=hnsw_config, - optimizers_config=optimizers_config, - wal_config=wal_config, - quantization_config=quantization_config, - init_from=init_from, - timeout=timeout, - ) - qdrant = cls( - client=client, - collection_name=collection_name, - embeddings=embedding, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - distance_strategy=distance_func, - vector_name=vector_name, - async_client=async_client, - ) - return qdrant - - @staticmethod - def _cosine_relevance_score_fn(distance: float) -> float: - """Normalize the distance to a score on a scale [0, 1].""" - return (distance + 1.0) / 2.0 - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - """ - - if self.distance_strategy == "COSINE": - return self._cosine_relevance_score_fn - elif self.distance_strategy == "DOT": - return self._max_inner_product_relevance_score_fn - elif self.distance_strategy == "EUCLID": - return self._euclidean_relevance_score_fn - else: - raise ValueError( - "Unknown distance strategy, must be cosine, " - "max_inner_product, or euclidean" - ) - - def _similarity_search_with_relevance_scores( - self, - query: str, - k: int = 4, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs and relevance scores in the range [0, 1]. - - 0 is dissimilar, 1 is most similar. - - Args: - query: input text - k: Number of Documents to return. Defaults to 4. - **kwargs: kwargs to be passed to similarity search. Should include: - score_threshold: Optional, a floating point value between 0 to 1 to - filter the resulting set of retrieved docs - - Returns: - List of Tuples of (doc, similarity_score) - """ - return self.similarity_search_with_score(query, k, **kwargs) - - @classmethod - def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[List[dict]], - content_payload_key: str, - metadata_payload_key: str, - ) -> List[dict]: - payloads = [] - for i, text in enumerate(texts): - if text is None: - raise ValueError( - "At least one of the texts is None. Please remove it before " - "calling .from_texts or .add_texts on Qdrant instance." 
- ) - metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - } - ) - - return payloads - - @classmethod - def _document_from_scored_point( - cls, - scored_point: Any, - collection_name: str, - content_payload_key: str, - metadata_payload_key: str, - ) -> Document: - metadata = scored_point.payload.get(metadata_payload_key) or {} - metadata["_id"] = scored_point.id - metadata["_collection_name"] = collection_name - return Document( - page_content=scored_point.payload.get(content_payload_key), - metadata=metadata, - ) - - def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]: - from qdrant_client.http import models as rest - - out = [] - - if isinstance(value, dict): - for _key, value in value.items(): - out.extend(self._build_condition(f"{key}.{_key}", value)) - elif isinstance(value, list): - for _value in value: - if isinstance(_value, dict): - out.extend(self._build_condition(f"{key}[]", _value)) - else: - out.extend(self._build_condition(f"{key}", _value)) - else: - out.append( - rest.FieldCondition( - key=f"{self.metadata_payload_key}.{key}", - match=rest.MatchValue(value=value), - ) - ) - - return out - - def _qdrant_filter_from_dict( - self, filter: Optional[DictFilter] - ) -> Optional[rest.Filter]: - from qdrant_client.http import models as rest - - if not filter: - return None - - return rest.Filter( - must=[ - condition - for key, value in filter.items() - for condition in self._build_condition(key, value) - ] - ) - - def _embed_query(self, query: str) -> List[float]: - """Embed query text. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - query: Query text. - - Returns: - List of floats representing the query embedding. - """ - if self.embeddings is not None: - embedding = self.embeddings.embed_query(query) - else: - if self._embeddings_function is not None: - embedding = self._embeddings_function(query) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - return embedding.tolist() if hasattr(embedding, "tolist") else embedding - - async def _aembed_query(self, query: str) -> List[float]: - """Embed query text asynchronously. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - query: Query text. - - Returns: - List of floats representing the query embedding. - """ - if self.embeddings is not None: - embedding = await self.embeddings.aembed_query(query) - else: - if self._embeddings_function is not None: - embedding = self._embeddings_function(query) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - return embedding.tolist() if hasattr(embedding, "tolist") else embedding - - def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]: - """Embed search texts. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - texts: Iterable of texts to embed. - - Returns: - List of floats representing the texts embedding. 
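The `_build_condition` / `_qdrant_filter_from_dict` helpers above turn a plain metadata dict into a qdrant-client `Filter`, namespacing each key under the metadata payload key. A sketch of the equivalent hand-written filter; field names and values are illustrative:

```python
from qdrant_client.http import models as rest

# Roughly what _qdrant_filter_from_dict({"topic": "mmr", "year": 2024}) produces
# when metadata_payload_key is the default "metadata".
qdrant_filter = rest.Filter(
    must=[
        rest.FieldCondition(key="metadata.topic", match=rest.MatchValue(value="mmr")),
        rest.FieldCondition(key="metadata.year", match=rest.MatchValue(value=2024)),
    ]
)
```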
- """ - if self.embeddings is not None: - embeddings = self.embeddings.embed_documents(list(texts)) - if hasattr(embeddings, "tolist"): - embeddings = embeddings.tolist() - elif self._embeddings_function is not None: - embeddings = [] - for text in texts: - embedding = self._embeddings_function(text) - if hasattr(embeddings, "tolist"): - embedding = embedding.tolist() - embeddings.append(embedding) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - - return embeddings - - async def _aembed_texts(self, texts: Iterable[str]) -> List[List[float]]: - """Embed search texts. - - Used to provide backward compatibility with `embedding_function` argument. - - Args: - texts: Iterable of texts to embed. - - Returns: - List of floats representing the texts embedding. - """ - if self.embeddings is not None: - embeddings = await self.embeddings.aembed_documents(list(texts)) - if hasattr(embeddings, "tolist"): - embeddings = embeddings.tolist() - elif self._embeddings_function is not None: - embeddings = [] - for text in texts: - embedding = self._embeddings_function(text) - if hasattr(embeddings, "tolist"): - embedding = embedding.tolist() - embeddings.append(embedding) - else: - raise ValueError("Neither of embeddings or embedding_function is set") - - return embeddings - - def _generate_rest_batches( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - ) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]: - from qdrant_client.http import models as rest - - texts_iterator = iter(texts) - metadatas_iterator = iter(metadatas or []) - ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) - while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata and id for each text in a batch - batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None - batch_ids = list(islice(ids_iterator, batch_size)) - - # Generate the embeddings for all the texts in a batch - batch_embeddings = self._embed_texts(batch_texts) - - points = [ - rest.PointStruct( - id=point_id, - vector=vector - if self.vector_name is None - else {self.vector_name: vector}, - payload=payload, - ) - for point_id, vector, payload in zip( - batch_ids, - batch_embeddings, - self._build_payloads( - batch_texts, - batch_metadatas, - self.content_payload_key, - self.metadata_payload_key, - ), - ) - ] - - yield batch_ids, points - - async def _agenerate_rest_batches( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - ) -> AsyncGenerator[Tuple[List[str], List[rest.PointStruct]], None]: - from qdrant_client.http import models as rest - - texts_iterator = iter(texts) - metadatas_iterator = iter(metadatas or []) - ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) - while batch_texts := list(islice(texts_iterator, batch_size)): - # Take the corresponding metadata and id for each text in a batch - batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None - batch_ids = list(islice(ids_iterator, batch_size)) - - # Generate the embeddings for all the texts in a batch - batch_embeddings = await self._aembed_texts(batch_texts) - - points = [ - rest.PointStruct( - id=point_id, - vector=vector - if self.vector_name is None - else {self.vector_name: vector}, - payload=payload, - ) - for point_id, vector, payload in zip( - batch_ids, - batch_embeddings, - 
self._build_payloads( - batch_texts, - batch_metadatas, - self.content_payload_key, - self.metadata_payload_key, - ), - ) - ] - - yield batch_ids, points - - @staticmethod - def _generate_clients( - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - path: Optional[str] = None, - **kwargs: Any, - ) -> Tuple[Any, Any]: - from qdrant_client import AsyncQdrantClient, QdrantClient - - sync_client = QdrantClient( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - **kwargs, - ) - - if location == ":memory:" or path is not None: - # Local Qdrant cannot co-exist with Sync and Async clients - # We fallback to sync operations in this case - async_client = None - else: - async_client = AsyncQdrantClient( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - **kwargs, - ) - - return sync_client, async_client diff --git a/libs/community/langchain_community/vectorstores/starrocks.py b/libs/community/langchain_community/vectorstores/starrocks.py index 8c24a794c..ea3d921b2 100644 --- a/libs/community/langchain_community/vectorstores/starrocks.py +++ b/libs/community/langchain_community/vectorstores/starrocks.py @@ -162,7 +162,7 @@ def __init__( config (StarRocksSettings): Configuration to StarRocks Client """ try: - import pymysql # type: ignore[import-untyped] + import pymysql # type: ignore[import-untyped, unused-ignore] except ImportError: raise ImportError( "Could not import pymysql python package. 
" @@ -190,7 +190,7 @@ def __init__( dim = len(embedding.embed_query("test")) self.schema = f"""\ -CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( +CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( {self.config.column_map["id"]} string, {self.config.column_map["document"]} string, {self.config.column_map["embedding"]} array, @@ -378,10 +378,10 @@ def _build_query_sql( where_str = "" q_str = f""" - SELECT + SELECT id as id, - {self.config.column_map["document"]} as document, - {self.config.column_map["metadata"]} as metadata, + {self.config.column_map["document"]} as document, + {self.config.column_map["metadata"]} as metadata, cosine_similarity_norm(array[{q_emb_str}], {self.config.column_map["embedding"]}) as dist, {self.config.column_map["embedding"]} as embedding diff --git a/libs/community/langchain_community/vectorstores/vdms.py b/libs/community/langchain_community/vectorstores/vdms.py deleted file mode 100644 index 9c010c7c3..000000000 --- a/libs/community/langchain_community/vectorstores/vdms.py +++ /dev/null @@ -1,1746 +0,0 @@ -from __future__ import annotations - -import base64 -import logging -import os -import uuid -from copy import deepcopy -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Literal, - Optional, - Sized, - Tuple, - Type, - Union, - get_args, -) - -import numpy as np -from langchain_core._api.deprecation import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore - -from langchain_community.vectorstores.utils import maximal_marginal_relevance - -if TYPE_CHECKING: - import vdms - - -DISTANCE_METRICS = Literal[ - "L2", # Euclidean Distance - "IP", # Inner Product -] -AVAILABLE_DISTANCE_METRICS: List[DISTANCE_METRICS] = list(get_args(DISTANCE_METRICS)) -ENGINES = Literal[ - "TileDBDense", # TileDB Dense - "TileDBSparse", # TileDB Sparse - "FaissFlat", # FAISS IndexFlat - "FaissIVFFlat", # FAISS IndexIVFFlat - "Flinng", # FLINNG -] -AVAILABLE_ENGINES: List[ENGINES] = list(get_args(ENGINES)) -DEFAULT_COLLECTION_NAME = "langchain" -DEFAULT_INSERT_BATCH_SIZE = 32 -# Number of Documents to return. -DEFAULT_K = 3 -# Number of Documents to fetch to pass to knn when filters applied. 
-DEFAULT_FETCH_K = DEFAULT_K * 5 -DEFAULT_PROPERTIES = ["_distance", "id", "content"] -INVALID_DOC_METADATA_KEYS = ["_distance", "content", "blob"] -INVALID_METADATA_VALUE = ["Missing property", None, {}] # type: List - - -logger = logging.getLogger(__name__) - - -def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None: - """ - Check that sizes of two variables are the same - - Args: - x: Variable to compare - y: Variable to compare - x_name: Name for variable x - y_name: Name for variable y - """ - if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y): - raise ValueError( - f"{x_name} and {y_name} expected to be equal length but " - f"len({x_name})={len(x)} and len({y_name})={len(y)}" - ) - return - - -def _results_to_docs(results: Any) -> List[Document]: - return [doc for doc, _ in _results_to_docs_and_scores(results)] - - -def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: - final_res: List[Any] = [] - try: - responses, blobs = results[0] - if ( - len(responses) > 0 - and "FindDescriptor" in responses[0] - and "entities" in responses[0]["FindDescriptor"] - ): - result_entities = responses[0]["FindDescriptor"]["entities"] - # result_blobs = blobs - for ent in result_entities: - distance = round(ent["_distance"], 10) - txt_contents = ent["content"] - for p in INVALID_DOC_METADATA_KEYS: - if p in ent: - del ent[p] - props = { - mkey: mval - for mkey, mval in ent.items() - if mval not in INVALID_METADATA_VALUE - } - - final_res.append( - ( - Document(page_content=txt_contents, metadata=props), - distance, - ) - ) - except Exception as e: - logger.warning(f"No results returned. Error while parsing results: {e}") - return final_res - - -def VDMS_Client(host: str = "localhost", port: int = 55555) -> vdms.vdms: - """VDMS client for the VDMS server. - - Args: - host: IP or hostname of VDMS server - port: Port to connect to VDMS server - """ - try: - import vdms - except ImportError: - raise ImportError( - "Could not import vdms python package. " - "Please install it with `pip install vdms." - ) - - client = vdms.vdms() - client.connect(host, port) - return client - - -@deprecated(since="0.3.18", removal="1.0.0", alternative_import="langchain_vdms.VDMS") -class VDMS(VectorStore): - """Intel Lab's VDMS for vector-store workloads. - - To use, you should have both: - - the ``vdms`` python package installed - - a host (str) and port (int) associated with a deployed VDMS Server - - Visit https://github.com/IntelLabs/vdms/wiki more information. - - IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. - - Args: - client: VDMS Client used to connect to VDMS server - collection_name: Name of data collection [Default: langchain] - distance_strategy: Method used to calculate distances. VDMS supports - "L2" (euclidean distance) or "IP" (inner product) [Default: L2] - engine: Underlying implementation for indexing and computing distances. - VDMS supports TileDBDense, TileDBSparse, FaissFlat, FaissIVFFlat, - and Flinng [Default: FaissFlat] - embedding: Any embedding function implementing - `langchain_core.embeddings.Embeddings` interface. - relevance_score_fn: Function for obtaining relevance score - - Example: - .. 
code-block:: python - - from langchain_huggingface import HuggingFaceEmbeddings - from langchain_community.vectorstores.vdms import VDMS, VDMS_Client - - model_name = "sentence-transformers/all-mpnet-base-v2" - vectorstore = VDMS( - client=VDMS_Client("localhost", 55555), - embedding=HuggingFaceEmbeddings(model_name=model_name), - collection_name="langchain-demo", - distance_strategy="L2", - engine="FaissFlat", - ) - """ - - def __init__( - self, - client: vdms.vdms, - *, - embedding: Optional[Embeddings] = None, - collection_name: str = DEFAULT_COLLECTION_NAME, # DescriptorSet name - distance_strategy: DISTANCE_METRICS = "L2", - engine: ENGINES = "FaissFlat", - relevance_score_fn: Optional[Callable[[float], float]] = None, - embedding_dimensions: Optional[int] = None, - ) -> None: - # Check required parameters - self._client = client - self.similarity_search_engine = engine - self.distance_strategy = distance_strategy - self.embedding = embedding - self._check_required_inputs(collection_name, embedding_dimensions) - - # Update other parameters - self.override_relevance_score_fn = relevance_score_fn - - # Initialize collection - self._collection_name = self.add_set( - collection_name, - engine=self.similarity_search_engine, - metric=self.distance_strategy, - ) - - @property - def embeddings(self) -> Optional[Embeddings]: - return self.embedding - - def _embed_documents(self, texts: List[str]) -> List[List[float]]: - if isinstance(self.embedding, Embeddings): - return self.embedding.embed_documents(texts) - else: - p_str = "Must provide `embedding` which is expected" - p_str += " to be an Embeddings object" - raise ValueError(p_str) - - def _embed_video(self, paths: List[str], **kwargs: Any) -> List[List[float]]: - if self.embedding is not None and hasattr(self.embedding, "embed_video"): - return self.embedding.embed_video(paths=paths, **kwargs) - else: - raise ValueError( - "Must provide `embedding` which has attribute `embed_video`" - ) - - def _embed_image(self, uris: List[str]) -> List[List[float]]: - if self.embedding is not None and hasattr(self.embedding, "embed_image"): - return self.embedding.embed_image(uris=uris) - else: - raise ValueError( - "Must provide `embedding` which has attribute `embed_image`" - ) - - def _embed_query(self, text: str) -> List[float]: - if isinstance(self.embedding, Embeddings): - return self.embedding.embed_query(text) - else: - raise ValueError( - "Must provide `embedding` which is expected to be an Embeddings object" - ) - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - """ - if self.override_relevance_score_fn is not None: - return self.override_relevance_score_fn - - # Default strategy is to rely on distance strategy provided - # in vectorstore constructor - if self.distance_strategy.lower() in ["ip", "l2"]: - return lambda x: x - else: - raise ValueError( - "No supported normalization function" - f" for distance_strategy of {self.distance_strategy}." - "Consider providing relevance_score_fn to VDMS constructor." 
- ) - - def _similarity_search_with_relevance_scores( - self, - query: str, - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - filter: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs and their similarity scores on a scale from 0 to 1.""" - if self.override_relevance_score_fn is None: - kwargs["normalize_distance"] = True - docs_and_scores = self.similarity_search_with_score( - query=query, - k=k, - fetch_k=fetch_k, - filter=filter, - **kwargs, - ) - - docs_and_rel_scores: List[Any] = [] - for doc, score in docs_and_scores: - if self.override_relevance_score_fn is None: - docs_and_rel_scores.append((doc, score)) - else: - docs_and_rel_scores.append( - ( - doc, - self.override_relevance_score_fn(score), - ) - ) - return docs_and_rel_scores - - def add( - self, - collection_name: str, - texts: List[str], - embeddings: List[List[float]], - metadatas: Optional[Union[List[None], List[Dict[str, Any]]]] = None, - ids: Optional[List[str]] = None, - ) -> List: - _len_check_if_sized(texts, embeddings, "texts", "embeddings") - - metadatas = metadatas if metadatas is not None else [None for _ in texts] - _len_check_if_sized(texts, metadatas, "texts", "metadatas") - - ids = ids if ids is not None else [str(uuid.uuid4()) for _ in texts] - _len_check_if_sized(texts, ids, "texts", "ids") - - all_queries: List[Any] = [] - all_blobs: List[Any] = [] - inserted_ids: List[Any] = [] - for meta, emb, doc, id in zip(metadatas, embeddings, texts, ids): - query, blob = self.__get_add_query( - collection_name, metadata=meta, embedding=emb, document=doc, id=id - ) - - if blob is not None: - all_queries.append(query) - all_blobs.append(blob) - inserted_ids.append(id) - - response, response_array = self.__run_vdms_query(all_queries, all_blobs) - - return inserted_ids - - def add_set( - self, - collection_name: str, - engine: ENGINES = "FaissFlat", - metric: DISTANCE_METRICS = "L2", - ) -> str: - query = _add_descriptorset( - "AddDescriptorSet", - collection_name, - self.embedding_dimension, - engine=getattr(engine, "value", engine), - metric=getattr(metric, "value", metric), - ) - - response, _ = self.__run_vdms_query([query]) - - if "FailedCommand" in response[0]: - raise ValueError(f"Failed to add collection {collection_name}") - - return collection_name - - def __delete( - self, - collection_name: str, - ids: Union[None, List[str]] = None, - constraints: Union[None, Dict[str, Any]] = None, - ) -> bool: - """ - Deletes entire collection if id is not provided - """ - all_queries: List[Any] = [] - all_blobs: List[Any] = [] - - collection_properties = self.__get_properties(collection_name) - results = {"list": collection_properties} - - if constraints is None: - constraints = {"_deletion": ["==", 1]} - else: - constraints["_deletion"] = ["==", 1] - - if ids is not None: - constraints["id"] = ["==", ids[0]] # if len(ids) > 1 else ids[0]] - - query = _add_descriptor( - "FindDescriptor", - collection_name, - label=None, - ref=None, - props=None, - link=None, - k_neighbors=None, - constraints=constraints, - results=results, - ) - - all_queries.append(query) - response, response_array = self.__run_vdms_query(all_queries, all_blobs) - - # Update/store indices after deletion - query = _add_descriptorset( - "FindDescriptorSet", collection_name, storeIndex=True - ) - responseSet, _ = self.__run_vdms_query([query], all_blobs) - return "FindDescriptor" in response[0] - - def __get_add_query( - self, - collection_name: str, - metadata: Optional[Any] = None, - embedding: 
Union[List[float], None] = None, - document: Optional[Any] = None, - id: Optional[str] = None, - ) -> Tuple[Dict[str, Dict[str, Any]], Union[bytes, None]]: - if id is None: - props: Dict[str, Any] = {} - else: - props = {"id": id} - id_exists, query = _check_descriptor_exists_by_id( - self._client, collection_name, id - ) - if id_exists: - skipped_value = { - prop_key: prop_val[-1] - for prop_key, prop_val in query["FindDescriptor"][ - "constraints" - ].items() - } - pstr = f"[!] Embedding with id ({id}) exists in DB;" - pstr += "Therefore, skipped and not inserted" - print(pstr) # noqa: T201 - print(f"\tSkipped values are: {skipped_value}") # noqa: T201 - return query, None - - if metadata: - props.update(metadata) - if document not in [None, ""]: - props["content"] = document - - for k in props.keys(): - if k not in self.collection_properties: - self.collection_properties.append(k) - - query = _add_descriptor( - "AddDescriptor", - collection_name, - label=None, - ref=None, - props=props, - link=None, - k_neighbors=None, - constraints=None, - results=None, - ) - - blob = embedding2bytes(embedding) - - return ( - query, - blob, - ) - - def __get_properties( - self, - collection_name: str, - unique_entity: Optional[bool] = False, - deletion: Optional[bool] = False, - ) -> List[str]: - find_query = _find_property_entity( - collection_name, unique_entity=unique_entity, deletion=deletion - ) - response, response_blob = self.__run_vdms_query([find_query]) - if len(response_blob) > 0: - collection_properties = _bytes2str(response_blob[0]).split(",") - else: - collection_properties = deepcopy(DEFAULT_PROPERTIES) - return collection_properties - - def __run_vdms_query( - self, - all_queries: List[Dict], - all_blobs: Optional[List] = [], - print_last_response: Optional[bool] = False, - ) -> Tuple[Any, Any]: - response, response_array = self._client.query(all_queries, all_blobs) - - _ = _check_valid_response(all_queries, response) - if print_last_response: - self._client.print_last_response() - return response, response_array - - def __update( - self, - collection_name: str, - ids: List[str], - documents: List[str], - embeddings: List[List[float]], - metadatas: Optional[Union[List[None], List[Dict[str, Any]]]] = None, - ) -> None: - """ - Updates (find, delete, add) a collection based on id. 
- If more than one collection returned with id, error occuers - """ - _len_check_if_sized(ids, documents, "ids", "documents") - - _len_check_if_sized(ids, embeddings, "ids", "embeddings") - - metadatas = metadatas if metadatas is not None else [None for _ in ids] - _len_check_if_sized(ids, metadatas, "ids", "metadatas") - - orig_props = self.__get_properties(collection_name) - - updated_ids: List[Any] = [] - for meta, emb, doc, id in zip(metadatas, embeddings, documents, ids): - results = {"list": self.collection_properties} - - constraints = {"_deletion": ["==", 1]} - - if id is not None: - constraints["id"] = ["==", id] - - query = _add_descriptor( - "FindDescriptor", - collection_name, - label=None, - ref=None, - props=None, - link=None, - k_neighbors=None, - constraints=constraints, - results=results, - ) - - response, response_array = self.__run_vdms_query([query]) - - query, blob = self.__get_add_query( - collection_name, - metadata=meta, - embedding=emb, - document=doc, - id=id, - ) - if blob is not None: - response, response_array = self.__run_vdms_query([query], [blob]) - updated_ids.append(id) - - self.__update_properties( - collection_name, orig_props, self.collection_properties - ) - - def __update_properties( - self, - collection_name: str, - current_collection_properties: List, - new_collection_properties: Optional[List], - ) -> None: - if new_collection_properties is not None: - old_collection_properties = deepcopy(current_collection_properties) - for prop in new_collection_properties: - if prop not in current_collection_properties: - current_collection_properties.append(prop) - - if current_collection_properties != old_collection_properties: - all_queries, blob_arr = _build_property_query( - collection_name, - command_type="update", - all_properties=current_collection_properties, - ) - response, _ = self.__run_vdms_query(all_queries, [blob_arr]) - - def add_images( - self, - uris: List[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - batch_size: int = DEFAULT_INSERT_BATCH_SIZE, - add_path: Optional[bool] = True, - **kwargs: Any, - ) -> List[str]: - """Run more images through the embeddings and add to the vectorstore. - - Images are added as embeddings (AddDescriptor) instead of separate - entity (AddImage) within VDMS to leverage similarity search capability - - Args: - uris: List of paths to the images to add to the vectorstore. - metadatas: Optional list of metadatas associated with the images. - ids: Optional list of unique IDs. - batch_size (int): Number of concurrent requests to send to the server. - add_path: Bool to add image path as metadata - - Returns: - List of ids from adding images into the vectorstore. 
- """ - # Map from uris to blobs to base64 - b64_texts = [self.encode_image(image_path=uri) for uri in uris] - - if add_path and metadatas: - for midx, uri in enumerate(uris): - metadatas[midx]["image_path"] = uri - elif add_path: - metadatas = [] - for uri in uris: - metadatas.append({"image_path": uri}) - - # Populate IDs - ids = ids if ids is not None else [str(uuid.uuid4()) for _ in uris] - - # Set embeddings - embeddings = self._embed_image(uris=uris) - - if metadatas is None: - metadatas = [{} for _ in uris] - else: - metadatas = [_validate_vdms_properties(m) for m in metadatas] - - self.add_from( - texts=b64_texts, - embeddings=embeddings, - ids=ids, - metadatas=metadatas, - batch_size=batch_size, - **kwargs, - ) - return ids - - def add_videos( - self, - paths: List[str], - texts: Optional[List[str]] = None, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - batch_size: int = 1, - add_path: Optional[bool] = True, - **kwargs: Any, - ) -> List[str]: - """Run videos through the embeddings and add to the vectorstore. - - Videos are added as embeddings (AddDescriptor) instead of separate - entity (AddVideo) within VDMS to leverage similarity search capability - - Args: - paths: List of paths to the videos to add to the vectorstore. - metadatas: Optional list of text associated with the videos. - metadatas: Optional list of metadatas associated with the videos. - ids: Optional list of unique IDs. - batch_size (int): Number of concurrent requests to send to the server. - add_path: Bool to add video path as metadata - - Returns: - List of ids from adding videos into the vectorstore. - """ - if texts is None: - texts = ["" for _ in paths] - - if add_path and metadatas: - for midx, path in enumerate(paths): - metadatas[midx]["video_path"] = path - elif add_path: - metadatas = [] - for path in paths: - metadatas.append({"video_path": path}) - - # Populate IDs - ids = ids if ids is not None else [str(uuid.uuid4()) for _ in paths] - - # Set embeddings - embeddings = self._embed_video(paths=paths, **kwargs) - - if metadatas is None: - metadatas = [{} for _ in paths] - - self.add_from( - texts=texts, - embeddings=embeddings, - ids=ids, - metadatas=metadatas, - batch_size=batch_size, - **kwargs, - ) - return ids - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - batch_size: int = DEFAULT_INSERT_BATCH_SIZE, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: List of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: Optional list of unique IDs. - batch_size (int): Number of concurrent requests to send to the server. - - Returns: - List of ids from adding the texts into the vectorstore. 
- """ - - texts = list(texts) - if ids is None: - ids = [str(uuid.uuid4()) for _ in texts] - - embeddings = self._embed_documents(texts) - - if metadatas is None: - metadatas = [{} for _ in texts] - else: - metadatas = [_validate_vdms_properties(m) for m in metadatas] - - inserted_ids = self.add_from( - texts=texts, - embeddings=embeddings, - ids=ids, - metadatas=metadatas, - batch_size=batch_size, - **kwargs, - ) - return inserted_ids - - def add_from( - self, - texts: List[str], - embeddings: List[List[float]], - ids: List[str], - metadatas: Optional[List[dict]] = None, - batch_size: int = DEFAULT_INSERT_BATCH_SIZE, - **kwargs: Any, - ) -> List[str]: - # Get initial properties - orig_props = self.__get_properties(self._collection_name) - inserted_ids: List[str] = [] - for start_idx in range(0, len(texts), batch_size): - end_idx = min(start_idx + batch_size, len(texts)) - - batch_texts = texts[start_idx:end_idx] - batch_embedding_vectors = embeddings[start_idx:end_idx] - batch_ids = ids[start_idx:end_idx] - if metadatas: - batch_metadatas = metadatas[start_idx:end_idx] - - result = self.add( - self._collection_name, - embeddings=batch_embedding_vectors, - texts=batch_texts, - metadatas=batch_metadatas, - ids=batch_ids, - ) - - inserted_ids.extend(result) - - # Update Properties - self.__update_properties( - self._collection_name, orig_props, self.collection_properties - ) - return inserted_ids - - def _check_required_inputs( - self, collection_name: str, embedding_dimensions: Union[int, None] - ) -> None: - # Check connection to client - if not self._client.is_connected(): - raise ValueError( - "VDMS client must be connected to a VDMS server." - + "Please use VDMS_Client to establish a connection" - ) - - # Check Distance Metric - if self.distance_strategy not in AVAILABLE_DISTANCE_METRICS: - raise ValueError("distance_strategy must be either 'L2' or 'IP'") - - # Check Engines - if self.similarity_search_engine not in AVAILABLE_ENGINES: - raise ValueError( - "engine must be either 'TileDBDense', 'TileDBSparse', " - + "'FaissFlat', 'FaissIVFFlat', or 'Flinng'" - ) - - # Check Embedding Func is provided and store dimension size - if self.embedding is None: - raise ValueError("Must provide embedding function") - - if embedding_dimensions is not None: - self.embedding_dimension = embedding_dimensions - elif self.embedding is not None and hasattr(self.embedding, "embed_query"): - self.embedding_dimension = len( - self._embed_query("This is a sample sentence.") - ) - elif self.embedding is not None and ( - hasattr(self.embedding, "embed_image") - or hasattr(self.embedding, "embed_video") - ): - if hasattr(self.embedding, "model"): - try: - self.embedding_dimension = ( - self.embedding.model.token_embedding.embedding_dim - ) - except ValueError: - raise ValueError( - "Embedding dimension needed. Please define embedding_dimensions" - ) - else: - raise ValueError( - "Embedding dimension needed. 
Please define embedding_dimensions" - ) - - # Check for properties - current_props = self.__get_properties(collection_name) - if hasattr(self, "collection_properties"): - self.collection_properties.extend(current_props) - else: - self.collection_properties: List[str] = current_props - - def count(self, collection_name: str) -> int: - all_queries: List[Any] = [] - all_blobs: List[Any] = [] - - results = {"count": "", "list": ["id"]} # collection_properties} - query = _add_descriptor( - "FindDescriptor", - collection_name, - label=None, - ref=None, - props=None, - link=None, - k_neighbors=None, - constraints=None, - results=results, - ) - - all_queries.append(query) - - response, response_array = self.__run_vdms_query(all_queries, all_blobs) - return response[0]["FindDescriptor"]["returned"] - - def decode_image(self, base64_image: str) -> bytes: - return base64.b64decode(base64_image) - - def delete( - self, - ids: Optional[List[str]] = None, - collection_name: Optional[str] = None, - constraints: Optional[Dict] = None, - **kwargs: Any, - ) -> bool: - """Delete by ID. These are the IDs in the vectorstore. - - Args: - ids: List of ids to delete. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. - """ - name = collection_name if collection_name is not None else self._collection_name - return self.__delete(name, ids=ids, constraints=constraints) - - def get_k_candidates( - self, - setname: str, - fetch_k: Optional[int], - results: Optional[Dict[str, Any]] = None, - all_blobs: Optional[List] = None, - normalize: Optional[bool] = False, - ) -> Tuple[List[Dict[str, Any]], List, float]: - max_dist = 1 - command_str = "FindDescriptor" - query = _add_descriptor( - command_str, - setname, - k_neighbors=fetch_k, - results=results, - ) - response, response_array = self.__run_vdms_query([query], all_blobs) - - if normalize and command_str in response[0]: - max_dist = response[0][command_str]["entities"][-1]["_distance"] - - return response, response_array, max_dist - - def get_descriptor_response( - self, - command_str: str, - setname: str, - k_neighbors: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - constraints: Optional[dict] = None, - results: Optional[Dict[str, Any]] = None, - query_embedding: Optional[List[float]] = None, - normalize_distance: bool = False, - ) -> Tuple[List[Dict[str, Any]], List]: - all_blobs: List[Any] = [] - blob = embedding2bytes(query_embedding) - if blob is not None: - all_blobs.append(blob) - - if constraints is None: - # K results returned - response, response_array, max_dist = self.get_k_candidates( - setname, k_neighbors, results, all_blobs, normalize=normalize_distance - ) - else: - if results is None: - results = {"list": ["id"]} - elif "list" not in results: - results["list"] = ["id"] - elif "id" not in results["list"]: - results["list"].append("id") - - # (1) Find docs satisfy constraints - query = _add_descriptor( - command_str, - setname, - constraints=constraints, - results=results, - ) - response, response_array = self.__run_vdms_query([query]) - if command_str in response[0] and response[0][command_str]["returned"] > 0: - ids_of_interest = [ - ent["id"] for ent in response[0][command_str]["entities"] - ] - else: - return [], [] - - # (2) Find top fetch_k results - response, response_array, max_dist = self.get_k_candidates( - setname, fetch_k, results, all_blobs, normalize=normalize_distance - ) - if command_str not in response[0] or ( - command_str in response[0] and 
response[0][command_str]["returned"] == 0 - ): - return [], [] - - # (3) Intersection of (1) & (2) using ids - new_entities: List[Dict] = [] - for ent in response[0][command_str]["entities"]: - if ent["id"] in ids_of_interest: - new_entities.append(ent) - if len(new_entities) == k_neighbors: - break - response[0][command_str]["entities"] = new_entities - response[0][command_str]["returned"] = len(new_entities) - if len(new_entities) < k_neighbors: - p_str = "Returned items < k_neighbors; Try increasing fetch_k" - print(p_str) # noqa: T201 - - if normalize_distance: - max_dist = 1.0 if max_dist in [0, np.inf] else max_dist - for ent_idx, ent in enumerate(response[0][command_str]["entities"]): - ent["_distance"] = ent["_distance"] / max_dist - response[0][command_str]["entities"][ent_idx]["_distance"] = ent[ - "_distance" - ] - - return response, response_array - - def encode_image(self, image_path: str) -> str: - with open(image_path, "rb") as f: - blob = f.read() - return base64.b64encode(blob).decode("utf-8") - - @classmethod - def from_documents( - cls: Type[VDMS], - documents: List[Document], - embedding: Optional[Embeddings] = None, - ids: Optional[List[str]] = None, - batch_size: int = DEFAULT_INSERT_BATCH_SIZE, - collection_name: str = DEFAULT_COLLECTION_NAME, # Add this line - **kwargs: Any, - ) -> VDMS: - """Create a VDMS vectorstore from a list of documents. - - Args: - collection_name (str): Name of the collection to create. - documents (List[Document]): List of documents to add to vectorstore. - embedding (Embeddings): Embedding function. Defaults to None. - ids (Optional[List[str]]): List of document IDs. Defaults to None. - batch_size (int): Number of concurrent requests to send to the server. - - Returns: - VDMS: VDMS vectorstore. - """ - client: vdms.vdms = kwargs["client"] - - return cls.from_texts( - client=client, - texts=[doc.page_content for doc in documents], - metadatas=[doc.metadata for doc in documents], - embedding=embedding, - ids=ids, - batch_size=batch_size, - collection_name=collection_name, - # **kwargs, - ) - - @classmethod - def from_texts( - cls: Type[VDMS], - texts: List[str], - embedding: Optional[Embeddings] = None, - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - batch_size: int = DEFAULT_INSERT_BATCH_SIZE, - collection_name: str = DEFAULT_COLLECTION_NAME, - **kwargs: Any, - ) -> VDMS: - """Create a VDMS vectorstore from a raw documents. - - Args: - texts (List[str]): List of texts to add to the collection. - embedding (Embeddings): Embedding function. Defaults to None. - metadatas (Optional[List[dict]]): List of metadatas. Defaults to None. - ids (Optional[List[str]]): List of document IDs. Defaults to None. - batch_size (int): Number of concurrent requests to send to the server. - collection_name (str): Name of the collection to create. - - Returns: - VDMS: VDMS vectorstore. - """ - client: vdms.vdms = kwargs["client"] - vdms_collection = cls( - collection_name=collection_name, - embedding=embedding, - client=client, - # **kwargs, - ) - if ids is None: - ids = [str(uuid.uuid4()) for _ in texts] - vdms_collection.add_texts( - texts=texts, - metadatas=metadatas, - ids=ids, - batch_size=batch_size, # **kwargs - ) - return vdms_collection - - def get( - self, - collection_name: str, - constraints: Optional[Dict] = None, - limit: Optional[int] = None, - include: List[str] = ["metadata"], - ) -> Tuple[Any, Any]: - """Gets the collection. - Get embeddings and their associated data from the data store. 
- If no constraints provided returns all embeddings up to limit. - - Args: - constraints: A dict used to filter results by. - E.g. `{"color" : ["==", "red"], "price": [">", 4.00]}`. Optional. - limit: The number of documents to return. Optional. - include: A list of what to include in the results. - Can contain `"embeddings"`, `"metadatas"`, `"documents"`. - Ids are always included. - Defaults to `["metadatas", "documents"]`. Optional. - """ - all_queries: List[Any] = [] - all_blobs: List[Any] = [] - - results: Dict[str, Any] = {"count": ""} - - if limit is not None: - results["limit"] = limit - - # Include metadata - if "metadata" in include: - collection_properties = self.__get_properties(collection_name) - results["list"] = collection_properties - - # Include embedding - if "embeddings" in include: - results["blob"] = True - - query = _add_descriptor( - "FindDescriptor", - collection_name, - k_neighbors=None, - constraints=constraints, - results=results, - ) - - all_queries.append(query) - - response, response_array = self.__run_vdms_query(all_queries, all_blobs) - return response, response_array - - def max_marginal_relevance_search( - self, - query: str, - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query (str): Query to look up. Text or path for image or video. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - if self.embedding is None: - raise ValueError( - "For MMR search, you must specify an embedding function oncreation." - ) - - # embedding_vector: List[float] = self._embed_query(query) - embedding_vector: List[float] - if not os.path.isfile(query) and hasattr(self.embedding, "embed_query"): - embedding_vector = self._embed_query(query) - elif os.path.isfile(query) and hasattr(self.embedding, "embed_image"): - embedding_vector = self._embed_image(uris=[query])[0] - elif os.path.isfile(query) and hasattr(self.embedding, "embed_video"): - embedding_vector = self._embed_video(paths=[query])[0] - else: - error_msg = f"Could not generate embedding for query '{query}'." - error_msg += "If using path for image or video, verify embedding model " - error_msg += "has callable functions 'embed_image' or 'embed_video'." - raise ValueError(error_msg) - - docs = self.max_marginal_relevance_search_by_vector( - embedding_vector, - k, - fetch_k, - lambda_mult=lambda_mult, - filter=filter, - ) - return docs - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. 
- k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - results = self.query_collection_embeddings( - query_embeddings=[embedding], - n_results=fetch_k, - filter=filter, - include=["metadatas", "documents", "distances", "embeddings"], - ) - - if len(results[0][1]) == 0: - # No results returned - return [] - else: - embedding_list = [ - list(_bytes2embedding(result)) for result in results[0][1] - ] - - mmr_selected = maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - embedding_list, - k=k, - lambda_mult=lambda_mult, - ) - - candidates = _results_to_docs(results) - - selected_results = [ - r for i, r in enumerate(candidates) if i in mmr_selected - ] - return selected_results - - def max_marginal_relevance_search_with_score( - self, - query: str, - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query (str): Query to look up. Text or path for image or video. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - if self.embedding is None: - raise ValueError( - "For MMR search, you must specify an embedding function oncreation." - ) - - if not os.path.isfile(query) and hasattr(self.embedding, "embed_query"): - embedding = self._embed_query(query) - elif os.path.isfile(query) and hasattr(self.embedding, "embed_image"): - embedding = self._embed_image(uris=[query])[0] - elif os.path.isfile(query) and hasattr(self.embedding, "embed_video"): - embedding = self._embed_video(paths=[query])[0] - else: - error_msg = f"Could not generate embedding for query '{query}'." - error_msg += "If using path for image or video, verify embedding model " - error_msg += "has callable functions 'embed_image' or 'embed_video'." - raise ValueError(error_msg) - - docs = self.max_marginal_relevance_search_with_score_by_vector( - embedding, - k, - fetch_k, - lambda_mult=lambda_mult, - filter=filter, - ) - return docs - - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: List[float], - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - lambda_mult: float = 0.5, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. 
- fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - results = self.query_collection_embeddings( - query_embeddings=[embedding], - n_results=fetch_k, - filter=filter, - include=["metadatas", "documents", "distances", "embeddings"], - ) - - if len(results[0][1]) == 0: - # No results returned - return [] - else: - embedding_list = [ - list(_bytes2embedding(result)) for result in results[0][1] - ] - - mmr_selected = maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - embedding_list, - k=k, - lambda_mult=lambda_mult, - ) - - candidates = _results_to_docs_and_scores(results) - - selected_results = [ - (r, s) for i, (r, s) in enumerate(candidates) if i in mmr_selected - ] - return selected_results - - def query_collection_embeddings( - self, - query_embeddings: Optional[List[List[float]]] = None, - collection_name: Optional[str] = None, - n_results: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - filter: Union[None, Dict[str, Any]] = None, - results: Union[None, Dict[str, Any]] = None, - normalize_distance: bool = False, - **kwargs: Any, - ) -> List[Tuple[Dict[str, Any], List]]: - all_responses: List[Any] = [] - - if collection_name is None: - collection_name = self._collection_name - - if query_embeddings is None: - return all_responses - - include = kwargs.get("include", ["metadatas"]) - if results is None and "metadatas" in include: - results = { - "list": self.collection_properties, - "blob": "embeddings" in include, - } - - for qemb in query_embeddings: - response, response_array = self.get_descriptor_response( - "FindDescriptor", - collection_name, - k_neighbors=n_results, - fetch_k=fetch_k, - constraints=filter, - results=results, - normalize_distance=normalize_distance, - query_embedding=qemb, - ) - all_responses.append([response, response_array]) - - return all_responses - - def similarity_search( - self, - query: str, - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Document]: - """Run similarity search with VDMS. - - Args: - query (str): Query to look up. Text or path for image or video. - k (int): Number of results to return. Defaults to 3. - fetch_k (int): Number of candidates to fetch for knn (>= k). - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List[Document]: List of documents most similar to the query text. - """ - docs_and_scores = self.similarity_search_with_score( - query, k=k, fetch_k=fetch_k, filter=filter, **kwargs - ) - return [doc for doc, _ in docs_and_scores] - - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to embedding vector. - Args: - embedding (List[float]): Embedding to look up documents similar to. - k (int): Number of Documents to return. Defaults to 3. - fetch_k (int): Number of candidates to fetch for knn (>= k). - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - Returns: - List of Documents most similar to the query vector. 
- """ - results = self.query_collection_embeddings( - query_embeddings=[embedding], - n_results=k, - fetch_k=fetch_k, - filter=filter, - **kwargs, - ) - - return _results_to_docs(results) - - def similarity_search_with_score( - self, - query: str, - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Run similarity search with VDMS with distance. - - Args: - query (str): Query to look up. Text or path for image or video. - k (int): Number of results to return. Defaults to 3. - fetch_k (int): Number of candidates to fetch for knn (>= k). - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List[Tuple[Document, float]]: List of documents most similar to - the query text and cosine distance in float for each. - Lower score represents more similarity. - """ - if self.embedding is None: - raise ValueError("Must provide embedding function") - else: - if not os.path.isfile(query) and hasattr(self.embedding, "embed_query"): - query_embedding: List[float] = self._embed_query(query) - elif os.path.isfile(query) and hasattr(self.embedding, "embed_image"): - query_embedding = self._embed_image(uris=[query])[0] - elif os.path.isfile(query) and hasattr(self.embedding, "embed_video"): - query_embedding = self._embed_video(paths=[query])[0] - else: - error_msg = f"Could not generate embedding for query '{query}'." - error_msg += "If using path for image or video, verify embedding model " - error_msg += "has callable functions 'embed_image' or 'embed_video'." - raise ValueError(error_msg) - - results = self.query_collection_embeddings( - query_embeddings=[query_embedding], - n_results=k, - fetch_k=fetch_k, - filter=filter, - **kwargs, - ) - - return _results_to_docs_and_scores(results) - - def similarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = DEFAULT_K, - fetch_k: int = DEFAULT_FETCH_K, - filter: Optional[Dict[str, List]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """ - Return docs most similar to embedding vector and similarity score. - - Args: - embedding (List[float]): Embedding to look up documents similar to. - k (int): Number of Documents to return. Defaults to 3. - fetch_k (int): Number of candidates to fetch for knn (>= k). - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List[Tuple[Document, float]]: List of documents most similar to - the query text. Lower score represents more similarity. - """ - - # kwargs["normalize_distance"] = True - - results = self.query_collection_embeddings( - query_embeddings=[embedding], - n_results=k, - fetch_k=fetch_k, - filter=filter, - **kwargs, - ) - return _results_to_docs_and_scores(results) - - def update_document( - self, collection_name: str, document_id: str, document: Document - ) -> None: - """Update a document in the collection. - - Args: - document_id (str): ID of the document to update. - document (Document): Document to update. - """ - return self.update_documents(collection_name, [document_id], [document]) - - def update_documents( - self, collection_name: str, ids: List[str], documents: List[Document] - ) -> None: - """Update a document in the collection. - - Args: - ids (List[str]): List of ids of the document to update. - documents (List[Document]): List of documents to update. 
- """ - text = [document.page_content for document in documents] - metadata = [ - _validate_vdms_properties(document.metadata) for document in documents - ] - embeddings = self._embed_documents(text) - - self.__update( - collection_name, - ids, - metadatas=metadata, - embeddings=embeddings, - documents=text, - ) - - -# VDMS UTILITY - - -def _add_descriptor( - command_str: str, - setname: str, - label: Optional[str] = None, - ref: Optional[int] = None, - props: Optional[dict] = None, - link: Optional[dict] = None, - k_neighbors: Optional[int] = None, - constraints: Optional[dict] = None, - results: Optional[dict] = None, -) -> Dict[str, Dict[str, Any]]: - entity: Dict[str, Any] = {"set": setname} - - if "Add" in command_str and label: - entity["label"] = label - - if ref is not None: - entity["_ref"] = ref - - if props not in INVALID_METADATA_VALUE: - entity["properties"] = props - - if "Add" in command_str and link is not None: - entity["link"] = link - - if "Find" in command_str and k_neighbors is not None: - entity["k_neighbors"] = int(k_neighbors) - - if "Find" in command_str and constraints not in INVALID_METADATA_VALUE: - entity["constraints"] = constraints - - if "Find" in command_str and results not in INVALID_METADATA_VALUE: - entity["results"] = results - - query = {command_str: entity} - return query - - -def _add_descriptorset( - command_str: str, - name: str, - num_dims: Optional[int] = None, - engine: Optional[str] = None, - metric: Optional[str] = None, - ref: Optional[int] = None, - props: Optional[Dict] = None, - link: Optional[Dict] = None, - storeIndex: bool = False, - constraints: Optional[Dict] = None, - results: Optional[Dict] = None, -) -> Dict[str, Any]: - if command_str == "AddDescriptorSet" and all( - var is not None for var in [name, num_dims] - ): - entity: Dict[str, Any] = { - "name": name, - "dimensions": num_dims, - } - - if engine is not None: - entity["engine"] = engine - - if metric is not None: - entity["metric"] = metric - - if ref is not None: - entity["_ref"] = ref - - if props not in [None, {}]: - entity["properties"] = props - - if link is not None: - entity["link"] = link - - elif command_str == "FindDescriptorSet": - entity = {"set": name} - - if storeIndex: - entity["storeIndex"] = storeIndex - - if constraints not in [None, {}]: - entity["constraints"] = constraints - - if results is not None: - entity["results"] = results - - else: - raise ValueError(f"Unknown command: {command_str}") - - query = {command_str: entity} - return query - - -def _add_entity_with_blob( - collection_name: str, all_properties: List -) -> Tuple[Dict[str, Any], bytes]: - all_properties_str = ",".join(all_properties) if len(all_properties) > 0 else "" - - querytype = "AddEntity" - entity: Dict[str, Any] = {} - entity["class"] = "properties" - entity["blob"] = True # New - - props: Dict[str, Any] = {"name": collection_name} - props["type"] = "queryable properties" - props["content"] = all_properties_str - entity["properties"] = props - - byte_data = _str2bytes(all_properties_str) - - query: Dict[str, Any] = {} - query[querytype] = entity - return query, byte_data - - -def _build_property_query( - collection_name: str, - command_type: str = "find", - all_properties: List = [], - ref: Optional[int] = None, -) -> Tuple[Any, Any]: - all_queries: List[Any] = [] - blob_arr: List[Any] = [] - - choices = ["find", "add", "update"] - if command_type.lower() not in choices: - raise ValueError("[!] Invalid type. 
Choices are : {}".format(",".join(choices))) - - if command_type.lower() == "find": - query = _find_property_entity(collection_name, unique_entity=True) - all_queries.append(query) - - elif command_type.lower() == "add": - query, byte_data = _add_entity_with_blob(collection_name, all_properties) - all_queries.append(query) - blob_arr.append(byte_data) - - elif command_type.lower() == "update": - # Find & Delete - query = _find_property_entity(collection_name, deletion=True) - all_queries.append(query) - - # Add - query, byte_data = _add_entity_with_blob(collection_name, all_properties) - all_queries.append(query) - blob_arr.append(byte_data) - - return all_queries, blob_arr - - -def _bytes2embedding(blob: bytes) -> Any: - emb = np.frombuffer(blob, dtype="float32") - return emb - - -def _bytes2str(in_bytes: bytes) -> str: - return in_bytes.decode() - - -def _get_cmds_from_query(all_queries: list) -> List[str]: - return list(set([k for q in all_queries for k in q.keys()])) - - -def _check_valid_response(all_queries: List[dict], response: Any) -> bool: - cmd_list = _get_cmds_from_query(all_queries) - valid_res = isinstance(response, list) and any( - cmd in response[0] - and "returned" in response[0][cmd] - and response[0][cmd]["returned"] > 0 - for cmd in cmd_list - ) - return valid_res - - -def _check_descriptor_exists_by_id( - client: vdms.vdms, - setname: str, - id: str, -) -> Tuple[bool, Any]: - constraints = {"id": ["==", id]} - findDescriptor = _add_descriptor( - "FindDescriptor", - setname, - constraints=constraints, - results={"list": ["id"], "count": ""}, - ) - all_queries = [findDescriptor] - res, _ = client.query(all_queries) - - valid_res = _check_valid_response(all_queries, res) - return valid_res, findDescriptor - - -def embedding2bytes(embedding: Union[List[float], None]) -> Union[bytes, None]: - """Convert embedding to bytes.""" - - blob = None - if embedding is not None: - emb = np.array(embedding, dtype="float32") - blob = emb.tobytes() - return blob - - -def _find_property_entity( - collection_name: str, - unique_entity: Optional[bool] = False, - deletion: Optional[bool] = False, -) -> Dict[str, Dict[str, Any]]: - querytype = "FindEntity" - entity: Dict[str, Any] = {} - entity["class"] = "properties" - if unique_entity: - entity["unique"] = unique_entity - - results: Dict[str, Any] = {} - results["blob"] = True - results["count"] = "" - results["list"] = ["content"] - entity["results"] = results - - constraints: Dict[str, Any] = {} - if deletion: - constraints["_deletion"] = ["==", 1] - constraints["name"] = ["==", collection_name] - entity["constraints"] = constraints - - query: Dict[str, Any] = {} - query[querytype] = entity - return query - - -def _str2bytes(in_str: str) -> bytes: - return str.encode(in_str) - - -def _validate_vdms_properties(metadata: Dict[str, Any]) -> Dict: - new_metadata: Dict[str, Any] = {} - for key, value in metadata.items(): - if not isinstance(value, list): - new_metadata[str(key)] = value - return new_metadata diff --git a/libs/community/langchain_community/vectorstores/weaviate.py b/libs/community/langchain_community/vectorstores/weaviate.py deleted file mode 100644 index 85989fe57..000000000 --- a/libs/community/langchain_community/vectorstores/weaviate.py +++ /dev/null @@ -1,534 +0,0 @@ -from __future__ import annotations - -import datetime -import os -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, -) -from uuid import uuid4 - -import numpy as np -from langchain_core._api import 
deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore - -from langchain_community.vectorstores.utils import maximal_marginal_relevance - -if TYPE_CHECKING: - import weaviate - - -def _default_schema(index_name: str, text_key: str) -> Dict: - return { - "class": index_name, - "properties": [ - { - "name": text_key, - "dataType": ["text"], - } - ], - } - - -def _create_weaviate_client( - url: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs: Any, -) -> weaviate.Client: - try: - import weaviate - except ImportError: - raise ImportError( - "Could not import weaviate python package. " - "Please install it with `pip install weaviate-client`" - ) - url = url or os.environ.get("WEAVIATE_URL") - api_key = api_key or os.environ.get("WEAVIATE_API_KEY") - auth = weaviate.auth.AuthApiKey(api_key=api_key) if api_key else None - return weaviate.Client(url=url, auth_client_secret=auth, **kwargs) - - -def _default_score_normalizer(val: float) -> float: - return 1 - 1 / (1 + np.exp(val)) - - -def _json_serializable(value: Any) -> Any: - if isinstance(value, datetime.datetime): - return value.isoformat() - return value - - -@deprecated( - since="0.3.18", - removal="1.0", - alternative_import="langchain_weaviate.WeaviateVectorStore", -) -class Weaviate(VectorStore): - """`Weaviate` vector store. - - To use, you should have the ``weaviate-client`` python package installed. - - Example: - .. code-block:: python - - import weaviate - from langchain_community.vectorstores import Weaviate - - client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) - weaviate = Weaviate(client, index_name, text_key) - - """ - - def __init__( - self, - client: Any, - index_name: str, - text_key: str, - embedding: Optional[Embeddings] = None, - attributes: Optional[List[str]] = None, - relevance_score_fn: Optional[ - Callable[[float], float] - ] = _default_score_normalizer, - by_text: bool = True, - ): - """Initialize with Weaviate client.""" - try: - import weaviate - except ImportError: - raise ImportError( - "Could not import weaviate python package. " - "Please install it with `pip install weaviate-client`." 
- ) - if not isinstance(client, weaviate.Client): - raise ValueError( - f"client should be an instance of weaviate.Client, got {type(client)}" - ) - self._client = client - self._index_name = index_name - self._embedding = embedding - self._text_key = text_key - self._query_attrs = [self._text_key] - self.relevance_score_fn = relevance_score_fn - self._by_text = by_text - if attributes is not None: - self._query_attrs.extend(attributes) - - @property - def embeddings(self) -> Optional[Embeddings]: - return self._embedding - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - return ( - self.relevance_score_fn - if self.relevance_score_fn - else _default_score_normalizer - ) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - """Upload texts with metadata (properties) to Weaviate.""" - from weaviate.util import get_valid_uuid - - ids = [] - embeddings: Optional[List[List[float]]] = None - if self._embedding: - if not isinstance(texts, list): - texts = list(texts) - embeddings = self._embedding.embed_documents(texts) - - with self._client.batch as batch: - for i, text in enumerate(texts): - data_properties = {self._text_key: text} - if metadatas is not None: - for key, val in metadatas[i].items(): - data_properties[key] = _json_serializable(val) - - # Allow for ids (consistent w/ other methods) - # # Or uuids (backwards compatible w/ existing arg) - # If the UUID of one of the objects already exists - # then the existing object will be replaced by the new object. - _id = get_valid_uuid(uuid4()) - if "uuids" in kwargs: - _id = kwargs["uuids"][i] - elif "ids" in kwargs: - _id = kwargs["ids"][i] - - batch.add_data_object( - data_object=data_properties, - class_name=self._index_name, - uuid=_id, - vector=embeddings[i] if embeddings else None, - tenant=kwargs.get("tenant"), - ) - ids.append(_id) - return ids - - def similarity_search( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. - """ - if self._by_text: - return self.similarity_search_by_text(query, k, **kwargs) - else: - if self._embedding is None: - raise ValueError( - "_embedding cannot be None for similarity_search when " - "_by_text=False" - ) - embedding = self._embedding.embed_query(query) - return self.similarity_search_by_vector(embedding, k, **kwargs) - - def similarity_search_by_text( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: - """Return docs most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. 
- """ - content: Dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("tenant"): - query_obj = query_obj.with_tenant(kwargs.get("tenant")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - result = query_obj.with_near_text(content).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def similarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any - ) -> List[Document]: - """Look up similar documents by embedding vector in Weaviate.""" - vector = {"vector": embedding} - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("tenant"): - query_obj = query_obj.with_tenant(kwargs.get("tenant")) - if kwargs.get("additional"): - query_obj = query_obj.with_additional(kwargs.get("additional")) - result = query_obj.with_near_vector(vector).with_limit(k).do() - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - docs = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - docs.append(Document(page_content=text, metadata=res)) - return docs - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - if self._embedding is not None: - embedding = self._embedding.embed_query(query) - else: - raise ValueError( - "max_marginal_relevance_search requires a suitable Embeddings object" - ) - - return self.max_marginal_relevance_search_by_vector( - embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. 
- - Returns: - List of Documents selected by maximal marginal relevance. - """ - vector = {"vector": embedding} - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("tenant"): - query_obj = query_obj.with_tenant(kwargs.get("tenant")) - results = ( - query_obj.with_additional("vector") - .with_near_vector(vector) - .with_limit(fetch_k) - .do() - ) - - payload = results["data"]["Get"][self._index_name] - embeddings = [result["_additional"]["vector"] for result in payload] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - - docs = [] - for idx in mmr_selected: - text = payload[idx].pop(self._text_key) - payload[idx].pop("_additional") - meta = payload[idx] - docs.append(Document(page_content=text, metadata=meta)) - return docs - - def similarity_search_with_score( - self, query: str, k: int = 4, **kwargs: Any - ) -> List[Tuple[Document, float]]: - """ - Return list of documents most similar to the query - text and cosine distance in float for each. - Lower score represents more similarity. - """ - if self._embedding is None: - raise ValueError( - "_embedding cannot be None for similarity_search_with_score" - ) - content: Dict[str, Any] = {"concepts": [query]} - if kwargs.get("search_distance"): - content["certainty"] = kwargs.get("search_distance") - query_obj = self._client.query.get(self._index_name, self._query_attrs) - if kwargs.get("where_filter"): - query_obj = query_obj.with_where(kwargs.get("where_filter")) - if kwargs.get("tenant"): - query_obj = query_obj.with_tenant(kwargs.get("tenant")) - - embedded_query = self._embedding.embed_query(query) - if not self._by_text: - vector = {"vector": embedded_query} - result = ( - query_obj.with_near_vector(vector) - .with_limit(k) - .with_additional("vector") - .do() - ) - else: - result = ( - query_obj.with_near_text(content) - .with_limit(k) - .with_additional("vector") - .do() - ) - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - docs_and_scores = [] - for res in result["data"]["Get"][self._index_name]: - text = res.pop(self._text_key) - score = np.dot(res["_additional"]["vector"], embedded_query) - docs_and_scores.append((Document(page_content=text, metadata=res), score)) - return docs_and_scores - - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - *, - client: Optional[weaviate.Client] = None, - weaviate_url: Optional[str] = None, - weaviate_api_key: Optional[str] = None, - batch_size: Optional[int] = None, - index_name: Optional[str] = None, - text_key: str = "text", - by_text: bool = False, - relevance_score_fn: Optional[ - Callable[[float], float] - ] = _default_score_normalizer, - **kwargs: Any, - ) -> Weaviate: - """Construct Weaviate wrapper from raw documents. - - This is a user-friendly interface that: - 1. Embeds documents. - 2. Creates a new index for the embeddings in the Weaviate instance. - 3. Adds the documents to the newly created Weaviate index. - - This is intended to be a quick way to get started. - - Args: - texts: Texts to add to vector store. - embedding: Text embedding model to use. - metadatas: Metadata associated with each text. - client: weaviate.Client to use. - weaviate_url: The Weaviate URL. If using Weaviate Cloud Services get it - from the ``Details`` tab. 
Can be passed in as a named param or by - setting the environment variable ``WEAVIATE_URL``. Should not be - specified if client is provided. - weaviate_api_key: The Weaviate API key. If enabled and using Weaviate Cloud - Services, get it from ``Details`` tab. Can be passed in as a named param - or by setting the environment variable ``WEAVIATE_API_KEY``. Should - not be specified if client is provided. - batch_size: Size of batch operations. - index_name: Index name. - text_key: Key to use for uploading/retrieving text to/from vectorstore. - by_text: Whether to search by text or by embedding. - relevance_score_fn: Function for converting whatever distance function the - vector store uses to a relevance score, which is a normalized similarity - score (0 means dissimilar, 1 means similar). - kwargs: Additional named parameters to pass to ``Weaviate.__init__()``. - - Example: - .. code-block:: python - - from langchain_community.embeddings import OpenAIEmbeddings - from langchain_community.vectorstores import Weaviate - - embeddings = OpenAIEmbeddings() - weaviate = Weaviate.from_texts( - texts, - embeddings, - weaviate_url="http://localhost:8080" - ) - """ - - try: - from weaviate.util import get_valid_uuid - except ImportError as e: - raise ImportError( - "Could not import weaviate python package. " - "Please install it with `pip install weaviate-client`" - ) from e - - client = client or _create_weaviate_client( - url=weaviate_url, - api_key=weaviate_api_key, - ) - if batch_size: - client.batch.configure(batch_size=batch_size) - - index_name = index_name or f"LangChain_{uuid4().hex}" - schema = _default_schema(index_name, text_key) - # check whether the index already exists - if not client.schema.exists(index_name): - client.schema.create_class(schema) - - embeddings = embedding.embed_documents(texts) if embedding else None - attributes = list(metadatas[0].keys()) if metadatas else None - - # If the UUID of one of the objects already exists - # then the existing object will be replaced by the new object. - if "uuids" in kwargs: - uuids = kwargs.pop("uuids") - else: - uuids = [get_valid_uuid(uuid4()) for _ in range(len(texts))] - - with client.batch as batch: - for i, text in enumerate(texts): - data_properties = { - text_key: text, - } - if metadatas is not None: - for key in metadatas[i].keys(): - data_properties[key] = metadatas[i][key] - - _id = uuids[i] - - # if an embedding strategy is not provided, we let - # weaviate create the embedding. Note that this will only - # work if weaviate has been installed with a vectorizer module - # like text2vec-contextionary for example - params = { - "uuid": _id, - "data_object": data_properties, - "class_name": index_name, - } - if embeddings is not None: - params["vector"] = embeddings[i] - - batch.add_data_object(**params) - - batch.flush() - - return cls( - client, - index_name, - text_key, - embedding=embedding, - attributes=attributes, - relevance_score_fn=relevance_score_fn, - by_text=by_text, - **kwargs, - ) - - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: - """Delete by vector IDs. - - Args: - ids: List of ids to delete. 
- """ - - if ids is None: - raise ValueError("No ids provided to delete.") - - # TODO: Check if this can be done in bulk - for id in ids: - self._client.data_object.delete(uuid=id) diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 3931459bf..91a0eb688 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -63,7 +63,11 @@ lint = [ "cffi<1.17.1; python_version < \"3.10\"", "cffi; python_version >= \"3.10\"", ] -dev = ["jupyter>=1.0.0,<2.0.0", "setuptools>=67.6.1,<68.0.0", "langchain-core"] +dev = [ + "jupyter>=1.0.0,<2.0.0", + "setuptools>=67.6.1,<68.0.0", + "langchain-core", +] typing = [ "mypy>=1.17.1,<2.0.0", "types-pyyaml>=6.0.12.2,<7.0.0.0", diff --git a/libs/community/scripts/check_pydantic.sh b/libs/community/scripts/check_pydantic.sh index c3ef67dc7..bcf6a8b56 100755 --- a/libs/community/scripts/check_pydantic.sh +++ b/libs/community/scripts/check_pydantic.sh @@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini # PRs that increase the current count will not be accepted. # PRs that decrease update the code in the repository # and allow decreasing the count of are welcome! -current_count=123 +current_count=118 if [ "$count" -gt "$current_count" ]; then echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator." diff --git a/libs/community/tests/integration_tests/agent/test_ainetwork_agent.py b/libs/community/tests/integration_tests/agent/test_ainetwork_agent.py index cc379c564..d4b7efdc2 100644 --- a/libs/community/tests/integration_tests/agent/test_ainetwork_agent.py +++ b/libs/community/tests/integration_tests/agent/test_ainetwork_agent.py @@ -11,7 +11,7 @@ from langchain_classic.agents import AgentType, initialize_agent from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.tools.ainetwork.utils import authenticate diff --git a/libs/community/tests/integration_tests/agent/test_powerbi_agent.py b/libs/community/tests/integration_tests/agent/test_powerbi_agent.py index 7f16ce8d2..7079caa69 100644 --- a/libs/community/tests/integration_tests/agent/test_powerbi_agent.py +++ b/libs/community/tests/integration_tests/agent/test_powerbi_agent.py @@ -2,7 +2,7 @@ from langchain_core.utils import get_from_env from langchain_community.agent_toolkits import PowerBIToolkit, create_pbi_agent -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.utilities.powerbi import PowerBIDataset @@ -26,11 +26,9 @@ def test_daxquery() -> None: NUM_ROWS = get_from_env("", "POWERBI_NUMROWS") fast_llm = ChatOpenAI( - temperature=0.5, max_tokens=1000, model_name="gpt-3.5-turbo", verbose=True - ) # type: ignore[call-arg] - smart_llm = ChatOpenAI( - temperature=0, max_tokens=100, model_name="gpt-4", verbose=True - ) # type: ignore[call-arg] + temperature=0.5, max_tokens=1000, model="gpt-3.5-turbo", verbose=True + ) + smart_llm = ChatOpenAI(temperature=0, max_tokens=100, model="gpt-4", verbose=True) toolkit = PowerBIToolkit( powerbi=PowerBIDataset( diff --git a/libs/community/tests/integration_tests/cache/test_astradb.py b/libs/community/tests/integration_tests/cache/test_astradb.py deleted file mode 100644 index 5ae984b2c..000000000 --- a/libs/community/tests/integration_tests/cache/test_astradb.py +++ /dev/null @@ -1,159 +0,0 
@@ -""" -Test AstraDB caches. Requires an Astra DB vector instance. - -Required to run this test: - - a recent `astrapy` Python package available - - an Astra DB instance; - - the two environment variables set: - export ASTRA_DB_API_ENDPOINT="https://-us-east1.apps.astra.datastax.com" - export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........." - - optionally this as well (otherwise defaults are used): - export ASTRA_DB_KEYSPACE="my_keyspace" -""" - -import os -from typing import AsyncIterator, Iterator - -import pytest -from langchain_classic.globals import get_llm_cache, set_llm_cache -from langchain_core.caches import BaseCache -from langchain_core.language_models import LLM -from langchain_core.outputs import Generation, LLMResult - -from langchain_community.cache import AstraDBCache, AstraDBSemanticCache -from langchain_community.utilities.astradb import SetupMode -from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings -from tests.unit_tests.llms.fake_llm import FakeLLM - - -def _has_env_vars() -> bool: - return all( - [ - "ASTRA_DB_APPLICATION_TOKEN" in os.environ, - "ASTRA_DB_API_ENDPOINT" in os.environ, - ] - ) - - -@pytest.fixture(scope="module") -def astradb_cache() -> Iterator[AstraDBCache]: - cache = AstraDBCache( - collection_name="lc_integration_test_cache", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - ) - yield cache - cache.collection.astra_db.delete_collection("lc_integration_test_cache") - - -@pytest.fixture -async def async_astradb_cache() -> AsyncIterator[AstraDBCache]: - cache = AstraDBCache( - collection_name="lc_integration_test_cache_async", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - setup_mode=SetupMode.ASYNC, - ) - yield cache - await cache.async_collection.astra_db.delete_collection( - "lc_integration_test_cache_async" - ) - - -@pytest.fixture(scope="module") -def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]: - fake_embe = FakeEmbeddings() - sem_cache = AstraDBSemanticCache( - collection_name="lc_integration_test_sem_cache", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - embedding=fake_embe, - ) - yield sem_cache - sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache") - - -@pytest.fixture -async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]: - fake_embe = FakeEmbeddings() - sem_cache = AstraDBSemanticCache( - collection_name="lc_integration_test_sem_cache_async", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - embedding=fake_embe, - setup_mode=SetupMode.ASYNC, - ) - yield sem_cache - sem_cache.collection.astra_db.delete_collection( - "lc_integration_test_sem_cache_async" - ) - - -@pytest.mark.requires("astrapy") -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. 
vars") -class TestAstraDBCaches: - def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None: - self.do_cache_test(FakeLLM(), astradb_cache, "foo") - - async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None: - await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo") - - def test_astradb_semantic_cache( - self, astradb_semantic_cache: AstraDBSemanticCache - ) -> None: - llm = FakeLLM() - self.do_cache_test(llm, astradb_semantic_cache, "bar") - output = llm.generate(["bar"]) # 'fizz' is erased away now - assert output != LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - astradb_semantic_cache.clear() - - async def test_astradb_semantic_cache_async( - self, async_astradb_semantic_cache: AstraDBSemanticCache - ) -> None: - llm = FakeLLM() - await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar") - output = await llm.agenerate(["bar"]) # 'fizz' is erased away now - assert output != LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - await async_astradb_semantic_cache.aclear() - - @staticmethod - def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None: - set_llm_cache(cache) - params = llm.dict() - params["stop"] = None - llm_string = str(sorted([(k, v) for k, v in params.items()])) - get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) # type: ignore[union-attr] - output = llm.generate([prompt]) - expected_output = LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - assert output == expected_output - # clear the cache - cache.clear() - - @staticmethod - async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None: - set_llm_cache(cache) - params = llm.dict() - params["stop"] = None - llm_string = str(sorted([(k, v) for k, v in params.items()])) - await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")]) # type: ignore[union-attr] - output = await llm.agenerate([prompt]) - expected_output = LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - assert output == expected_output - # clear the cache - await cache.aclear() diff --git a/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py b/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py index 59da98054..30939d876 100644 --- a/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/libs/community/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -8,7 +8,7 @@ from langchain_core.prompts import PromptTemplate from langchain_core.tracers.context import tracing_v2_enabled -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.llms import OpenAI questions = [ diff --git a/libs/community/tests/integration_tests/chat_message_histories/test_neo4j.py b/libs/community/tests/integration_tests/chat_message_histories/test_neo4j.py deleted file mode 100644 index 5ab1af546..000000000 --- a/libs/community/tests/integration_tests/chat_message_histories/test_neo4j.py +++ /dev/null @@ -1,66 +0,0 @@ -import os - -from langchain_core.messages import AIMessage, HumanMessage - -from langchain_community.chat_message_histories import Neo4jChatMessageHistory -from langchain_community.graphs import Neo4jGraph - - -def test_add_messages() -> None: - """Basic testing: adding messages to the Neo4jChatMessageHistory.""" - assert os.environ.get("NEO4J_URI") is not None - assert 
os.environ.get("NEO4J_USERNAME") is not None - assert os.environ.get("NEO4J_PASSWORD") is not None - message_store = Neo4jChatMessageHistory("23334") - message_store.clear() - assert len(message_store.messages) == 0 - message_store.add_user_message("Hello! Language Chain!") - message_store.add_ai_message("Hi Guys!") - - # create another message store to check if the messages are stored correctly - message_store_another = Neo4jChatMessageHistory("46666") - message_store_another.clear() - assert len(message_store_another.messages) == 0 - message_store_another.add_user_message("Hello! Bot!") - message_store_another.add_ai_message("Hi there!") - message_store_another.add_user_message("How's this pr going?") - - # Now check if the messages are stored in the database correctly - assert len(message_store.messages) == 2 - assert isinstance(message_store.messages[0], HumanMessage) - assert isinstance(message_store.messages[1], AIMessage) - assert message_store.messages[0].content == "Hello! Language Chain!" - assert message_store.messages[1].content == "Hi Guys!" - - assert len(message_store_another.messages) == 3 - assert isinstance(message_store_another.messages[0], HumanMessage) - assert isinstance(message_store_another.messages[1], AIMessage) - assert isinstance(message_store_another.messages[2], HumanMessage) - assert message_store_another.messages[0].content == "Hello! Bot!" - assert message_store_another.messages[1].content == "Hi there!" - assert message_store_another.messages[2].content == "How's this pr going?" - - # Now clear the first history - message_store.clear() - assert len(message_store.messages) == 0 - assert len(message_store_another.messages) == 3 - message_store_another.clear() - assert len(message_store.messages) == 0 - assert len(message_store_another.messages) == 0 - - -def test_add_messages_graph_object() -> None: - """Basic testing: Passing driver through graph object.""" - assert os.environ.get("NEO4J_URI") is not None - assert os.environ.get("NEO4J_USERNAME") is not None - assert os.environ.get("NEO4J_PASSWORD") is not None - graph = Neo4jGraph() - # rewrite env for testing - os.environ["NEO4J_USERNAME"] = "foo" - message_store = Neo4jChatMessageHistory("23334", graph=graph) - message_store.clear() - assert len(message_store.messages) == 0 - message_store.add_user_message("Hello! 
Language Chain!") - message_store.add_ai_message("Hi Guys!") - # Now check if the messages are stored in the database correctly - assert len(message_store.messages) == 2 diff --git a/libs/community/tests/integration_tests/chat_models/test_anthropic.py b/libs/community/tests/integration_tests/chat_models/test_anthropic.py deleted file mode 100644 index 38a5bbb75..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_anthropic.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Test Anthropic API wrapper.""" - -from typing import List - -import pytest -from langchain_core.callbacks import CallbackManager -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_core.outputs import ChatGeneration, LLMResult - -from langchain_community.chat_models.anthropic import ( - ChatAnthropic, -) -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler - - -@pytest.mark.scheduled -def test_anthropic_call() -> None: - """Test valid call to anthropic.""" - chat = ChatAnthropic(model="test") # type: ignore[call-arg] - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def test_anthropic_generate() -> None: - """Test generate method of anthropic.""" - chat = ChatAnthropic(model="test") # type: ignore[call-arg] - chat_messages: List[List[BaseMessage]] = [ - [HumanMessage(content="How many toes do dogs have?")] - ] - messages_copy = [messages.copy() for messages in chat_messages] - result: LLMResult = chat.generate(chat_messages) - assert isinstance(result, LLMResult) - for response in result.generations[0]: - assert isinstance(response, ChatGeneration) - assert isinstance(response.text, str) - assert response.text == response.message.content - assert chat_messages == messages_copy - - -@pytest.mark.scheduled -def test_anthropic_streaming() -> None: - """Test streaming tokens from anthropic.""" - chat = ChatAnthropic(model="test", streaming=True) # type: ignore[call-arg] - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def test_anthropic_streaming_callback() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = ChatAnthropic( # type: ignore[call-arg] - model="test", - streaming=True, - callback_manager=callback_manager, - verbose=True, - ) - message = HumanMessage(content="Write me a sentence with 10 words.") - chat.invoke([message]) - assert callback_handler.llm_streams > 1 - - -@pytest.mark.scheduled -async def test_anthropic_async_streaming_callback() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = ChatAnthropic( # type: ignore[call-arg] - model="test", - streaming=True, - callback_manager=callback_manager, - verbose=True, - ) - chat_messages: List[BaseMessage] = [ - HumanMessage(content="How many toes do dogs have?") - ] - result: LLMResult = await chat.agenerate([chat_messages]) - assert callback_handler.llm_streams > 1 - assert isinstance(result, LLMResult) - for response in result.generations[0]: - assert isinstance(response, ChatGeneration) - assert isinstance(response.text, str) - assert 
response.text == response.message.content diff --git a/libs/community/tests/integration_tests/chat_models/test_azure_openai.py b/libs/community/tests/integration_tests/chat_models/test_azure_openai.py deleted file mode 100644 index b472736bf..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_azure_openai.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Test AzureChatOpenAI wrapper.""" - -import os -from typing import Any - -import pytest -from langchain_core.callbacks import CallbackManager -from langchain_core.messages import BaseMessage, HumanMessage -from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult - -from langchain_community.chat_models import AzureChatOpenAI -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler - -OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "") -OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "") -OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "") -DEPLOYMENT_NAME = os.environ.get( - "AZURE_OPENAI_DEPLOYMENT_NAME", - os.environ.get("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", ""), -) - - -def _get_llm(**kwargs: Any) -> AzureChatOpenAI: - return AzureChatOpenAI( # type: ignore[call-arg] - deployment_name=DEPLOYMENT_NAME, - openai_api_version=OPENAI_API_VERSION, - azure_endpoint=OPENAI_API_BASE, - openai_api_key=OPENAI_API_KEY, - **kwargs, - ) - - -@pytest.mark.scheduled -@pytest.fixture -def llm() -> AzureChatOpenAI: - return _get_llm( - max_tokens=10, - ) - - -def test_chat_openai(llm: AzureChatOpenAI) -> None: - """Test AzureChatOpenAI wrapper.""" - message = HumanMessage(content="Hello") - response = llm.invoke([message]) - assert isinstance(response, BaseMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def test_chat_openai_generate() -> None: - """Test AzureChatOpenAI wrapper with generate.""" - chat = _get_llm(max_tokens=10, n=2) - message = HumanMessage(content="Hello") - response = chat.generate([[message], [message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 2 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -def test_chat_openai_multiple_completions() -> None: - """Test AzureChatOpenAI wrapper with multiple completions.""" - chat = _get_llm(max_tokens=10, n=5) - message = HumanMessage(content="Hello") - response = chat._generate([message]) - assert isinstance(response, ChatResult) - assert len(response.generations) == 5 - for generation in response.generations: - assert isinstance(generation.message, BaseMessage) - assert isinstance(generation.message.content, str) - - -@pytest.mark.scheduled -def test_chat_openai_streaming() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = _get_llm( - max_tokens=10, - streaming=True, - temperature=0, - callback_manager=callback_manager, - verbose=True, - ) - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert callback_handler.llm_streams > 0 - assert isinstance(response, BaseMessage) - - -@pytest.mark.scheduled -def test_chat_openai_streaming_generation_info() -> None: - """Test that generation info is preserved when streaming.""" - - class 
_FakeCallback(FakeCallbackHandler): - saved_things: dict = {} - - def on_llm_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: - # Save the generation - self.saved_things["generation"] = args[0] - - callback = _FakeCallback() - callback_manager = CallbackManager([callback]) - chat = _get_llm( - max_tokens=2, - temperature=0, - callback_manager=callback_manager, - ) - list(chat.stream("hi")) - generation = callback.saved_things["generation"] - # `Hello!` is two tokens, assert that that is what is returned - assert generation.generations[0][0].text == "Hello!" - - -@pytest.mark.scheduled -async def test_async_chat_openai() -> None: - """Test async generation.""" - chat = _get_llm(max_tokens=10, n=2) - message = HumanMessage(content="Hello") - response = await chat.agenerate([[message], [message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 2 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -async def test_async_chat_openai_streaming() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = _get_llm( - max_tokens=10, - streaming=True, - temperature=0, - callback_manager=callback_manager, - verbose=True, - ) - message = HumanMessage(content="Hello") - response = await chat.agenerate([[message], [message]]) - assert callback_handler.llm_streams > 0 - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 1 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -def test_openai_streaming(llm: AzureChatOpenAI) -> None: - """Test streaming tokens from OpenAI.""" - - for token in llm.stream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_astream(llm: AzureChatOpenAI) -> None: - """Test streaming tokens from OpenAI.""" - async for token in llm.astream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_abatch(llm: AzureChatOpenAI) -> None: - """Test streaming tokens from AzureChatOpenAI.""" - - result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_abatch_tags(llm: AzureChatOpenAI) -> None: - """Test batch tokens from AzureChatOpenAI.""" - - result = await llm.abatch( - ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} - ) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -def test_openai_batch(llm: AzureChatOpenAI) -> None: - """Test batch tokens from AzureChatOpenAI.""" - - result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None: - """Test invoke tokens from AzureChatOpenAI.""" - - result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) - assert isinstance(result.content, str) - 
- -@pytest.mark.scheduled -def test_openai_invoke(llm: AzureChatOpenAI) -> None: - """Test invoke tokens from AzureChatOpenAI.""" - - result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) - assert isinstance(result.content, str) diff --git a/libs/community/tests/integration_tests/chat_models/test_bedrock.py b/libs/community/tests/integration_tests/chat_models/test_bedrock.py deleted file mode 100644 index 46a9293d3..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_bedrock.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Test Bedrock chat model.""" - -from typing import Any, cast - -import pytest -from langchain_core.callbacks import CallbackManager -from langchain_core.messages import ( - AIMessageChunk, - BaseMessage, - HumanMessage, - SystemMessage, -) -from langchain_core.outputs import ChatGeneration, LLMResult - -from langchain_community.chat_models import BedrockChat -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler - - -@pytest.fixture -def chat() -> BedrockChat: - return BedrockChat(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0}) # type: ignore[call-arg] - - -@pytest.mark.scheduled -def test_chat_bedrock(chat: BedrockChat) -> None: - """Test BedrockChat wrapper.""" - system = SystemMessage(content="You are a helpful assistant.") - human = HumanMessage(content="Hello") - response = chat.invoke([system, human]) - assert isinstance(response, BaseMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def test_chat_bedrock_generate(chat: BedrockChat) -> None: - """Test BedrockChat wrapper with generate.""" - message = HumanMessage(content="Hello") - response = chat.generate([[message], [message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None: - """Test BedrockChat wrapper with generate.""" - message = HumanMessage(content="Hello") - response = chat.generate([[message], [message]]) - assert isinstance(response, LLMResult) - assert isinstance(response.llm_output, dict) - - usage = response.llm_output["usage"] - assert usage["prompt_tokens"] == 20 - assert usage["completion_tokens"] > 0 - assert usage["total_tokens"] > 0 - - -@pytest.mark.scheduled -def test_chat_bedrock_streaming() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = BedrockChat( # type: ignore[call-arg] - model_id="anthropic.claude-v2", - streaming=True, - callback_manager=callback_manager, - verbose=True, - ) - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert callback_handler.llm_streams > 0 - assert isinstance(response, BaseMessage) - - -@pytest.mark.scheduled -def test_chat_bedrock_streaming_generation_info() -> None: - """Test that generation info is preserved when streaming.""" - - class _FakeCallback(FakeCallbackHandler): - saved_things: dict = {} - - def on_llm_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: - # Save the generation - self.saved_things["generation"] = args[0] - - callback = _FakeCallback() - callback_manager = CallbackManager([callback]) - chat = BedrockChat( # type: 
ignore[call-arg] - model_id="anthropic.claude-v2", - callback_manager=callback_manager, - ) - list(chat.stream("hi")) - generation = callback.saved_things["generation"] - # `Hello!` is two tokens, assert that that is what is returned - assert generation.generations[0][0].text == "Hello!" - - -@pytest.mark.scheduled -def test_bedrock_streaming(chat: BedrockChat) -> None: - """Test streaming tokens from OpenAI.""" - - full = None - for token in chat.stream("I'm Pickle Rick"): - full = token if full is None else full + token - assert isinstance(token.content, str) - assert isinstance(cast(AIMessageChunk, full).content, str) - - -@pytest.mark.scheduled -async def test_bedrock_astream(chat: BedrockChat) -> None: - """Test streaming tokens from OpenAI.""" - - async for token in chat.astream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_bedrock_abatch(chat: BedrockChat) -> None: - """Test streaming tokens from BedrockChat.""" - result = await chat.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_bedrock_abatch_tags(chat: BedrockChat) -> None: - """Test batch tokens from BedrockChat.""" - result = await chat.abatch( - ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} - ) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -def test_bedrock_batch(chat: BedrockChat) -> None: - """Test batch tokens from BedrockChat.""" - result = chat.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_bedrock_ainvoke(chat: BedrockChat) -> None: - """Test invoke tokens from BedrockChat.""" - result = await chat.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) - assert isinstance(result.content, str) - - -@pytest.mark.scheduled -def test_bedrock_invoke(chat: BedrockChat) -> None: - """Test invoke tokens from BedrockChat.""" - result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) - assert isinstance(result.content, str) - assert all([k in result.response_metadata for k in ("usage", "model_id")]) - assert result.response_metadata["usage"]["prompt_tokens"] == 13 diff --git a/libs/community/tests/integration_tests/chat_models/test_fireworks.py b/libs/community/tests/integration_tests/chat_models/test_fireworks.py deleted file mode 100644 index da60c2020..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_fireworks.py +++ /dev/null @@ -1,179 +0,0 @@ -"""Test ChatFireworks wrapper.""" - -import sys -from typing import cast - -import pytest -from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage -from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult - -from langchain_community.chat_models.fireworks import ChatFireworks - -if sys.version_info < (3, 9): - pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True) - - -@pytest.fixture -def chat() -> ChatFireworks: - return ChatFireworks(model_kwargs={"temperature": 0, "max_tokens": 512}) - - -@pytest.mark.scheduled -def test_chat_fireworks(chat: ChatFireworks) -> None: - """Test ChatFireworks wrapper.""" - message = HumanMessage(content="What is the weather in Redwood City, CA today") - response = chat.invoke([message]) - assert isinstance(response, BaseMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def 
test_chat_fireworks_model() -> None: - """Test ChatFireworks wrapper handles model_name.""" - chat = ChatFireworks(model="foo") - assert chat.model == "foo" - - -@pytest.mark.scheduled -def test_chat_fireworks_system_message(chat: ChatFireworks) -> None: - """Test ChatFireworks wrapper with system message.""" - system_message = SystemMessage(content="You are to chat with the user.") - human_message = HumanMessage(content="Hello") - response = chat.invoke([system_message, human_message]) - assert isinstance(response, BaseMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def test_chat_fireworks_generate() -> None: - """Test ChatFireworks wrapper with generate.""" - chat = ChatFireworks(model_kwargs={"n": 2}) - message = HumanMessage(content="Hello") - response = chat.generate([[message], [message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 2 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -def test_chat_fireworks_multiple_completions() -> None: - """Test ChatFireworks wrapper with multiple completions.""" - chat = ChatFireworks(model_kwargs={"n": 5}) - message = HumanMessage(content="Hello") - response = chat._generate([message]) - assert isinstance(response, ChatResult) - assert len(response.generations) == 5 - for generation in response.generations: - assert isinstance(generation.message, BaseMessage) - assert isinstance(generation.message.content, str) - - -@pytest.mark.scheduled -def test_chat_fireworks_llm_output_contains_model_id(chat: ChatFireworks) -> None: - """Test llm_output contains model_id.""" - message = HumanMessage(content="Hello") - llm_result = chat.generate([[message]]) - assert llm_result.llm_output is not None - assert llm_result.llm_output["model"] == chat.model - - -@pytest.mark.scheduled -def test_fireworks_invoke(chat: ChatFireworks) -> None: - """Tests chat completion with invoke""" - result = chat.invoke("How is the weather in New York today?", stop=[","]) - assert isinstance(result.content, str) - assert result.content[-1] == "," - - -@pytest.mark.scheduled -async def test_fireworks_ainvoke(chat: ChatFireworks) -> None: - """Tests chat completion with invoke""" - result = await chat.ainvoke("How is the weather in New York today?", stop=[","]) - assert isinstance(result.content, str) - assert result.content[-1] == "," - - -@pytest.mark.scheduled -def test_fireworks_batch(chat: ChatFireworks) -> None: - """Test batch tokens from ChatFireworks.""" - result = chat.batch( - [ - "What is the weather in Redwood City, CA today?", - "What is the weather in Redwood City, CA today?", - "What is the weather in Redwood City, CA today?", - ], - config={"max_concurrency": 2}, - stop=[","], - ) - for token in result: - assert isinstance(token.content, str) - assert token.content[-1] == ",", token.content - - -@pytest.mark.scheduled -async def test_fireworks_abatch(chat: ChatFireworks) -> None: - """Test batch tokens from ChatFireworks.""" - result = await chat.abatch( - [ - "What is the weather in Redwood City, CA today?", - "What is the weather in Redwood City, CA today?", - ], - config={"max_concurrency": 5}, - stop=[","], - ) - for token in result: - assert isinstance(token.content, str) - assert token.content[-1] == "," - - -@pytest.mark.scheduled -def 
test_fireworks_streaming(chat: ChatFireworks) -> None: - """Test streaming tokens from Fireworks.""" - - for token in chat.stream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -def test_fireworks_streaming_stop_words(chat: ChatFireworks) -> None: - """Test streaming tokens with stop words.""" - - last_token = "" - for token in chat.stream("I'm Pickle Rick", stop=[","]): - last_token = cast(str, token.content) - assert isinstance(token.content, str) - assert last_token[-1] == "," - - -@pytest.mark.scheduled -async def test_chat_fireworks_agenerate() -> None: - """Test ChatFireworks wrapper with generate.""" - chat = ChatFireworks(model_kwargs={"n": 2}) - message = HumanMessage(content="Hello") - response = await chat.agenerate([[message], [message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 2 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -async def test_fireworks_astream(chat: ChatFireworks) -> None: - """Test streaming tokens from Fireworks.""" - - last_token = "" - async for token in chat.astream( - "Who's the best quarterback in the NFL?", stop=[","] - ): - last_token = cast(str, token.content) - assert isinstance(token.content, str) - assert last_token[-1] == "," diff --git a/libs/community/tests/integration_tests/chat_models/test_litellm.py b/libs/community/tests/integration_tests/chat_models/test_litellm.py deleted file mode 100644 index e17d3d8d9..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_litellm.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Test Anthropic API wrapper.""" - -from typing import List - -from langchain_core.callbacks import ( - CallbackManager, -) -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_core.outputs import ChatGeneration, LLMResult - -from langchain_community.chat_models.litellm import ChatLiteLLM -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler - - -def test_litellm_call() -> None: - """Test valid call to litellm.""" - chat = ChatLiteLLM( - model="test", - ) - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -def test_litellm_generate() -> None: - """Test generate method of anthropic.""" - chat = ChatLiteLLM(model="test") - chat_messages: List[List[BaseMessage]] = [ - [HumanMessage(content="How many toes do dogs have?")] - ] - messages_copy = [messages.copy() for messages in chat_messages] - result: LLMResult = chat.generate(chat_messages) - assert isinstance(result, LLMResult) - for response in result.generations[0]: - assert isinstance(response, ChatGeneration) - assert isinstance(response.text, str) - assert response.text == response.message.content - assert chat_messages == messages_copy - - -def test_litellm_streaming() -> None: - """Test streaming tokens from anthropic.""" - chat = ChatLiteLLM(model="test", streaming=True) - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -def test_litellm_streaming_callback() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = 
FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = ChatLiteLLM( - model="test", - streaming=True, - callbacks=callback_manager, - verbose=True, - ) - message = HumanMessage(content="Write me a sentence with 10 words.") - chat.invoke([message]) - assert callback_handler.llm_streams > 1 diff --git a/libs/community/tests/integration_tests/chat_models/test_litellm_router.py b/libs/community/tests/integration_tests/chat_models/test_litellm_router.py deleted file mode 100644 index c2d8ce85e..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_litellm_router.py +++ /dev/null @@ -1,377 +0,0 @@ -"""Test LiteLLM Router API wrapper.""" - -import asyncio -import queue -import threading -from copy import deepcopy -from typing import ( - Any, - AsyncGenerator, - AsyncIterator, - Dict, - Generator, - List, - Tuple, - Union, - cast, -) - -import pytest -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_core.outputs import ChatGeneration, LLMResult - -from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler - -model_group_gpt4 = "gpt-4" -model_group_to_test = "gpt-35-turbo" -fake_model_prefix = "azure/fake-deployment-name-" -fake_models_names = [fake_model_prefix + suffix for suffix in ["1", "2"]] -fake_api_key = "fakekeyvalue" -fake_api_version = "XXXX-XX-XX" -fake_api_base = "https://faketesturl/" -fake_chunks = ["This is ", "a fake answer."] -fake_answer = "".join(fake_chunks) -token_usage_key_name = "token_usage" - -model_list = [ - { - "model_name": model_group_gpt4, - "litellm_params": { - "model": fake_models_names[0], - "api_key": fake_api_key, - "api_version": fake_api_version, - "api_base": fake_api_base, - }, - }, - { - "model_name": model_group_to_test, - "litellm_params": { - "model": fake_models_names[1], - "api_key": fake_api_key, - "api_version": fake_api_version, - "api_base": fake_api_base, - }, - }, -] - - -# from https://stackoverflow.com/a/78573267 -def aiter_to_iter(it: AsyncIterator) -> Generator: - "Convert an async iterator into a regular (sync) iterator." 
- q_in: queue.SimpleQueue = queue.SimpleQueue() - q_out: queue.SimpleQueue = queue.SimpleQueue() - - async def threadmain() -> None: - try: - # Wait until the sync generator requests an item before continuing - while q_in.get(): - q_out.put((True, await it.__anext__())) - except StopAsyncIteration: - q_out.put((False, None)) - except BaseException as ex: - q_out.put((False, ex)) - - thread = threading.Thread(target=asyncio.run, args=(threadmain(),), daemon=True) - thread.start() - - try: - while True: - q_in.put(True) - cont, result = q_out.get() - if cont: - yield result - elif result is None: - break - else: - raise result - finally: - q_in.put(False) - - -class FakeCompletion: - def __init__(self) -> None: - self.seen_inputs: List[Any] = [] - - @staticmethod - def _get_new_result_and_choices( - base_result: Dict[str, Any], - ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: - result = deepcopy(base_result) - choices = cast(List[Dict[str, Any]], result["choices"]) - return result, choices - - async def _get_fake_results_agenerator( - self, **kwargs: Any - ) -> AsyncGenerator[Dict[str, Any], None]: - from litellm import Usage - - self.seen_inputs.append(kwargs) - base_result = { - "choices": [ - { - "index": 0, - } - ], - "created": 0, - "id": "", - "model": model_group_to_test, - "object": "chat.completion", - } - if kwargs["stream"]: - for chunk_index in range(0, len(fake_chunks)): - result, choices = self._get_new_result_and_choices(base_result) - choice = choices[0] - choice["delta"] = { - "role": "assistant", - "content": fake_chunks[chunk_index], - "function_call": None, - } - choice["finish_reason"] = None - # no usage here, since no usage from OpenAI API for streaming yet - # https://community.openai.com/t/usage-info-in-api-responses/18862 - yield result - - result, choices = self._get_new_result_and_choices(base_result) - choice = choices[0] - choice["delta"] = {} - choice["finish_reason"] = "stop" - # no usage here, since no usage from OpenAI API for streaming yet - # https://community.openai.com/t/usage-info-in-api-responses/18862 - yield result - else: - result, choices = self._get_new_result_and_choices(base_result) - choice = choices[0] - choice["message"] = { - "content": fake_answer, - "role": "assistant", - } - choice["finish_reason"] = "stop" - result["usage"] = Usage( - completion_tokens=1, prompt_tokens=2, total_tokens=3 - ) - yield result - - def completion(self, **kwargs: Any) -> Union[List, Dict[str, Any]]: - agen = self._get_fake_results_agenerator(**kwargs) - synchronous_iter = aiter_to_iter(agen) - if kwargs["stream"]: - results: List[Dict[str, Any]] = [] - while True: - try: - results.append(synchronous_iter.__next__()) - except StopIteration: - break - return results - else: - # there is only one result for non-streaming - return synchronous_iter.__next__() - - async def acompletion( - self, **kwargs: Any - ) -> Union[AsyncGenerator[Dict[str, Any], None], Dict[str, Any]]: - agen = self._get_fake_results_agenerator(**kwargs) - if kwargs["stream"]: - return agen - else: - # there is only one result for non-streaming - return await agen.__anext__() - - def check_inputs(self, expected_num_calls: int) -> None: - assert len(self.seen_inputs) == expected_num_calls - for kwargs in self.seen_inputs: - metadata = kwargs["metadata"] - - assert metadata["model_group"] == model_group_to_test - - # LiteLLM router chooses one model name from the model_list - assert kwargs["model"] in fake_models_names - assert metadata["deployment"] in fake_models_names - - assert 
kwargs["api_key"] == fake_api_key - assert kwargs["api_version"] == fake_api_version - assert kwargs["api_base"] == fake_api_base - - -@pytest.fixture -def fake_completion() -> FakeCompletion: - """Fake AI completion for testing.""" - import litellm - - fake_completion = FakeCompletion() - - # Turn off LiteLLM's built-in telemetry - litellm.telemetry = False - litellm.completion = fake_completion.completion - litellm.acompletion = fake_completion.acompletion - return fake_completion - - -@pytest.fixture -def litellm_router() -> Any: - """LiteLLM router for testing.""" - from litellm import Router - - return Router(model_list=model_list) - - -@pytest.mark.scheduled -@pytest.mark.enable_socket -def test_litellm_router_call( - fake_completion: FakeCompletion, litellm_router: Any -) -> None: - """Test valid call to LiteLLM Router.""" - chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test) - message = HumanMessage(content="Hello") - - response = chat.invoke([message]) - - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - assert response.content == fake_answer - # no usage check here, since response is only an AIMessage - fake_completion.check_inputs(expected_num_calls=1) - - -@pytest.mark.scheduled -@pytest.mark.enable_socket -def test_litellm_router_generate( - fake_completion: FakeCompletion, litellm_router: Any -) -> None: - """Test generate method of LiteLLM Router.""" - chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test) - chat_messages: List[List[BaseMessage]] = [ - [HumanMessage(content="How many toes do dogs have?")] - ] - messages_copy = [messages.copy() for messages in chat_messages] - - result: LLMResult = chat.generate(chat_messages) - - assert isinstance(result, LLMResult) - for generations in result.generations: - assert len(generations) == 1 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.message.content == generation.text - assert generation.message.content == fake_answer - assert chat_messages == messages_copy - assert result.llm_output is not None - assert result.llm_output[token_usage_key_name] == { - "completion_tokens": 1, - "completion_tokens_details": None, - "prompt_tokens": 2, - "prompt_tokens_details": None, - "total_tokens": 3, - } - fake_completion.check_inputs(expected_num_calls=1) - - -@pytest.mark.scheduled -@pytest.mark.enable_socket -def test_litellm_router_streaming( - fake_completion: FakeCompletion, litellm_router: Any -) -> None: - """Test streaming tokens from LiteLLM Router.""" - chat = ChatLiteLLMRouter( - router=litellm_router, model_name=model_group_to_test, streaming=True - ) - message = HumanMessage(content="Hello") - - response = chat.invoke([message]) - - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - assert response.content == fake_answer - # no usage check here, since response is only an AIMessage - fake_completion.check_inputs(expected_num_calls=1) - - -@pytest.mark.scheduled -@pytest.mark.enable_socket -def test_litellm_router_streaming_callback( - fake_completion: FakeCompletion, litellm_router: Any -) -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - chat = ChatLiteLLMRouter( - router=litellm_router, - model_name=model_group_to_test, - streaming=True, - callbacks=[callback_handler], - verbose=True, - ) - message = HumanMessage(content="Write me a 
sentence with 10 words.") - - response = chat.invoke([message]) - - assert callback_handler.llm_streams > 1 - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - assert response.content == fake_answer - # no usage check here, since response is only an AIMessage - fake_completion.check_inputs(expected_num_calls=1) - - -@pytest.mark.scheduled -@pytest.mark.enable_socket -async def test_async_litellm_router( - fake_completion: FakeCompletion, litellm_router: Any -) -> None: - """Test async generation.""" - chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test) - message = HumanMessage(content="Hello") - - response = await chat.agenerate([[message], [message]]) - - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 1 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.message.content == generation.text - assert generation.message.content == fake_answer - assert response.llm_output is not None - assert response.llm_output[token_usage_key_name] == { - "completion_tokens": 2, - "completion_tokens_details": None, - "prompt_tokens": 4, - "prompt_tokens_details": None, - "total_tokens": 6, - } - fake_completion.check_inputs(expected_num_calls=2) - - -@pytest.mark.scheduled -@pytest.mark.enable_socket -async def test_async_litellm_router_streaming( - fake_completion: FakeCompletion, litellm_router: Any -) -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - chat = ChatLiteLLMRouter( - router=litellm_router, - model_name=model_group_to_test, - streaming=True, - callbacks=[callback_handler], - verbose=True, - ) - message = HumanMessage(content="Hello") - - response = await chat.agenerate([[message], [message]]) - - assert callback_handler.llm_streams > 0 - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 1 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.message.content == generation.text - assert generation.message.content == fake_answer - # no usage check here, since no usage from OpenAI API for streaming yet - # https://community.openai.com/t/usage-info-in-api-responses/18862 - fake_completion.check_inputs(expected_num_calls=2) diff --git a/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py b/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py deleted file mode 100644 index d034ece43..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_litellm_standard.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Standard LangChain interface tests""" - -from typing import Type - -from langchain_core.language_models import BaseChatModel -from langchain_tests.integration_tests import ChatModelIntegrationTests - -from langchain_community.chat_models.litellm import ChatLiteLLM - - -class TestLiteLLMStandard(ChatModelIntegrationTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatLiteLLM - - @property - def chat_model_params(self) -> dict: - return { - "model": "ollama/mistral", - # Needed to get the usage object when streaming. 
See https://docs.litellm.ai/docs/completion/usage#streaming-usage - "model_kwargs": {"stream_options": {"include_usage": True}}, - } diff --git a/libs/community/tests/integration_tests/chat_models/test_llamacpp.py b/libs/community/tests/integration_tests/chat_models/test_llamacpp.py index 589f6e306..23983ca25 100644 --- a/libs/community/tests/integration_tests/chat_models/test_llamacpp.py +++ b/libs/community/tests/integration_tests/chat_models/test_llamacpp.py @@ -11,7 +11,6 @@ class Joke(BaseModel): # TODO: replace with standard integration tests -# See example in tests/integration_tests/chat_models/test_litellm.py def test_structured_output() -> None: llm = ChatLlamaCpp(model_path="/path/to/Meta-Llama-3.1-8B-Instruct.Q4_K_M.gguf") structured_llm = llm.with_structured_output(Joke) diff --git a/libs/community/tests/integration_tests/chat_models/test_openai.py b/libs/community/tests/integration_tests/chat_models/test_openai.py deleted file mode 100644 index 6add0bd2a..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_openai.py +++ /dev/null @@ -1,333 +0,0 @@ -"""Test ChatOpenAI wrapper.""" - -from typing import Any, Optional - -import pytest -from langchain_core.callbacks import CallbackManager -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langchain_core.outputs import ( - ChatGeneration, - ChatResult, - LLMResult, -) -from langchain_core.prompts import ChatPromptTemplate -from pydantic import BaseModel, Field - -from langchain_community.chat_models.openai import ChatOpenAI -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler - - -@pytest.mark.scheduled -def test_chat_openai() -> None: - """Test ChatOpenAI wrapper.""" - chat = ChatOpenAI( - temperature=0.7, - base_url=None, - organization=None, - openai_proxy=None, - timeout=10.0, - max_retries=3, - http_client=None, - n=1, - max_tokens=10, - default_headers=None, - default_query=None, - ) - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert isinstance(response, BaseMessage) - assert isinstance(response.content, str) - - -def test_chat_openai_model() -> None: - """Test ChatOpenAI wrapper handles model_name.""" - chat = ChatOpenAI(model="foo") - assert chat.model_name == "foo" - chat = ChatOpenAI(model_name="bar") # type: ignore[call-arg] - assert chat.model_name == "bar" - - -def test_chat_openai_system_message() -> None: - """Test ChatOpenAI wrapper with system message.""" - chat = ChatOpenAI(max_tokens=10) - system_message = SystemMessage(content="You are to chat with the user.") - human_message = HumanMessage(content="Hello") - response = chat.invoke([system_message, human_message]) - assert isinstance(response, BaseMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def test_chat_openai_generate() -> None: - """Test ChatOpenAI wrapper with generate.""" - chat = ChatOpenAI(max_tokens=10, n=2) - message = HumanMessage(content="Hello") - response = chat.generate([[message], [message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - assert response.llm_output - for generations in response.generations: - assert len(generations) == 2 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -def test_chat_openai_multiple_completions() -> None: - """Test ChatOpenAI wrapper with multiple completions.""" 
- chat = ChatOpenAI(max_tokens=10, n=5) - message = HumanMessage(content="Hello") - response = chat._generate([message]) - assert isinstance(response, ChatResult) - assert len(response.generations) == 5 - for generation in response.generations: - assert isinstance(generation.message, BaseMessage) - assert isinstance(generation.message.content, str) - - -@pytest.mark.scheduled -def test_chat_openai_streaming() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = ChatOpenAI( - max_tokens=10, - streaming=True, - temperature=0, - callbacks=callback_manager, - verbose=True, - ) - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert callback_handler.llm_streams > 0 - assert isinstance(response, BaseMessage) - - -@pytest.mark.scheduled -def test_chat_openai_streaming_generation_info() -> None: - """Test that generation info is preserved when streaming.""" - - class _FakeCallback(FakeCallbackHandler): - saved_things: dict = {} - - def on_llm_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: - # Save the generation - self.saved_things["generation"] = args[0] - - callback = _FakeCallback() - callback_manager = CallbackManager([callback]) - chat = ChatOpenAI( - max_tokens=2, - temperature=0, - callbacks=callback_manager, - ) - list(chat.stream("hi")) - generation = callback.saved_things["generation"] - # `Hello!` is two tokens, assert that that is what is returned - assert generation.generations[0][0].text == "Hello!" - - -def test_chat_openai_llm_output_contains_model_name() -> None: - """Test llm_output contains model_name.""" - chat = ChatOpenAI(max_tokens=10) - message = HumanMessage(content="Hello") - llm_result = chat.generate([[message]]) - assert llm_result.llm_output is not None - assert llm_result.llm_output["model_name"] == chat.model_name - - -def test_chat_openai_streaming_llm_output_contains_model_name() -> None: - """Test llm_output contains model_name.""" - chat = ChatOpenAI(max_tokens=10, streaming=True) - message = HumanMessage(content="Hello") - llm_result = chat.generate([[message]]) - assert llm_result.llm_output is not None - assert llm_result.llm_output["model_name"] == chat.model_name - - -def test_chat_openai_invalid_streaming_params() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - with pytest.raises(ValueError): - ChatOpenAI( - max_tokens=10, - streaming=True, - temperature=0, - n=5, - ) - - -@pytest.mark.scheduled -async def test_async_chat_openai() -> None: - """Test async generation.""" - chat = ChatOpenAI(max_tokens=10, n=2) - message = HumanMessage(content="Hello") - response = await chat.agenerate([[message], [message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - assert response.llm_output - for generations in response.generations: - assert len(generations) == 2 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -async def test_async_chat_openai_streaming() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" - callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = ChatOpenAI( - max_tokens=10, - streaming=True, - temperature=0, - callbacks=callback_manager, - verbose=True, - ) - message = 
HumanMessage(content="Hello") - response = await chat.agenerate([[message], [message]]) - assert callback_handler.llm_streams > 0 - assert isinstance(response, LLMResult) - assert len(response.generations) == 2 - for generations in response.generations: - assert len(generations) == 1 - for generation in generations: - assert isinstance(generation, ChatGeneration) - assert isinstance(generation.text, str) - assert generation.text == generation.message.content - - -@pytest.mark.scheduled -async def test_async_chat_openai_bind_functions() -> None: - """Test ChatOpenAI wrapper with multiple completions.""" - - class Person(BaseModel): - """Identifying information about a person.""" - - name: str = Field(..., title="Name", description="The person's name") - age: int = Field(..., title="Age", description="The person's age") - fav_food: Optional[str] = Field( - default=None, title="Fav Food", description="The person's favorite food" - ) - - chat = ChatOpenAI( - max_tokens=30, - n=1, - streaming=True, - ).bind_functions(functions=[Person], function_call="Person") - - prompt = ChatPromptTemplate.from_messages( - [ - ("system", "Use the provided Person function"), - ("user", "{input}"), - ] - ) - - chain = prompt | chat - - message = HumanMessage(content="Sally is 13 years old") - response = await chain.abatch([{"input": message}]) - - assert isinstance(response, list) - assert len(response) == 1 - for generation in response: - assert isinstance(generation, AIMessage) - - -def test_chat_openai_extra_kwargs() -> None: - """Test extra kwargs to chat openai.""" - # Check that foo is saved in extra_kwargs. - llm = ChatOpenAI(foo=3, max_tokens=10) # type: ignore[call-arg] - assert llm.max_tokens == 10 - assert llm.model_kwargs == {"foo": 3} - - # Test that if extra_kwargs are provided, they are added to it. 
- llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg] - assert llm.model_kwargs == {"foo": 3, "bar": 2} - - # Test that if provided twice it errors - with pytest.raises(ValueError): - ChatOpenAI(foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg] - - # Test that if explicit param is specified in kwargs it errors - with pytest.raises(ValueError): - ChatOpenAI(model_kwargs={"temperature": 0.2}) - - # Test that "model" cannot be specified in kwargs - with pytest.raises(ValueError): - ChatOpenAI(model_kwargs={"model": "gpt-3.5-turbo-instruct"}) - - -@pytest.mark.scheduled -def test_openai_streaming() -> None: - """Test streaming tokens from OpenAI.""" - llm = ChatOpenAI(max_tokens=10) - - for token in llm.stream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_astream() -> None: - """Test streaming tokens from OpenAI.""" - llm = ChatOpenAI(max_tokens=10) - - async for token in llm.astream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_abatch() -> None: - """Test streaming tokens from ChatOpenAI.""" - llm = ChatOpenAI(max_tokens=10) - - result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_abatch_tags() -> None: - """Test batch tokens from ChatOpenAI.""" - llm = ChatOpenAI(max_tokens=10) - - result = await llm.abatch( - ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} - ) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -def test_openai_batch() -> None: - """Test batch tokens from ChatOpenAI.""" - llm = ChatOpenAI(max_tokens=10) - - result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -@pytest.mark.scheduled -async def test_openai_ainvoke() -> None: - """Test invoke tokens from ChatOpenAI.""" - llm = ChatOpenAI(max_tokens=10) - - result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) - assert isinstance(result.content, str) - - -@pytest.mark.scheduled -def test_openai_invoke() -> None: - """Test invoke tokens from ChatOpenAI.""" - llm = ChatOpenAI(max_tokens=10) - - result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) - assert isinstance(result.content, str) diff --git a/libs/community/tests/integration_tests/chat_models/test_perplexity.py b/libs/community/tests/integration_tests/chat_models/test_perplexity.py deleted file mode 100644 index 5288fccc9..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_perplexity.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Standard LangChain interface tests""" - -from typing import Type - -import pytest -from langchain_core.language_models import BaseChatModel -from langchain_tests.integration_tests import ChatModelIntegrationTests - -from langchain_community.chat_models import ChatPerplexity - - -class TestPerplexityStandard(ChatModelIntegrationTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatPerplexity - - @property - def chat_model_params(self) -> dict: - return {"model": "sonar"} - - @pytest.mark.xfail(reason="TODO: handle in integration.") - def test_double_messages_conversation(self, model: BaseChatModel) -> None: - super().test_double_messages_conversation(model) - - @pytest.mark.xfail(reason="Raises 400: Custom stop words not supported.") - def 
test_stop_sequence(self, model: BaseChatModel) -> None: - super().test_stop_sequence(model) diff --git a/libs/community/tests/integration_tests/chat_models/test_promptlayer_openai.py b/libs/community/tests/integration_tests/chat_models/test_promptlayer_openai.py index 3df4d370c..50187a753 100644 --- a/libs/community/tests/integration_tests/chat_models/test_promptlayer_openai.py +++ b/libs/community/tests/integration_tests/chat_models/test_promptlayer_openai.py @@ -1,7 +1,6 @@ """Test PromptLayerChatOpenAI wrapper.""" import pytest -from langchain_core.callbacks import CallbackManager from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult @@ -11,7 +10,7 @@ def test_promptlayer_chat_openai() -> None: """Test PromptLayerChatOpenAI wrapper.""" - chat = PromptLayerChatOpenAI(max_tokens=10) # type: ignore[call-arg] + chat = PromptLayerChatOpenAI(max_tokens=10) message = HumanMessage(content="Hello") response = chat.invoke([message]) assert isinstance(response, BaseMessage) @@ -20,7 +19,7 @@ def test_promptlayer_chat_openai() -> None: def test_promptlayer_chat_openai_system_message() -> None: """Test PromptLayerChatOpenAI wrapper with system message.""" - chat = PromptLayerChatOpenAI(max_tokens=10) # type: ignore[call-arg] + chat = PromptLayerChatOpenAI(max_tokens=10) system_message = SystemMessage(content="You are to chat with the user.") human_message = HumanMessage(content="Hello") response = chat.invoke([system_message, human_message]) @@ -30,7 +29,7 @@ def test_promptlayer_chat_openai_system_message() -> None: def test_promptlayer_chat_openai_generate() -> None: """Test PromptLayerChatOpenAI wrapper with generate.""" - chat = PromptLayerChatOpenAI(max_tokens=10, n=2) # type: ignore[call-arg] + chat = PromptLayerChatOpenAI(max_tokens=10, n=2) message = HumanMessage(content="Hello") response = chat.generate([[message], [message]]) assert isinstance(response, LLMResult) @@ -45,7 +44,7 @@ def test_promptlayer_chat_openai_generate() -> None: def test_promptlayer_chat_openai_multiple_completions() -> None: """Test PromptLayerChatOpenAI wrapper with multiple completions.""" - chat = PromptLayerChatOpenAI(max_tokens=10, n=5) # type: ignore[call-arg] + chat = PromptLayerChatOpenAI(max_tokens=10, n=5) message = HumanMessage(content="Hello") response = chat._generate([message]) assert isinstance(response, ChatResult) @@ -58,12 +57,11 @@ def test_promptlayer_chat_openai_multiple_completions() -> None: def test_promptlayer_chat_openai_streaming() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = PromptLayerChatOpenAI( # type: ignore[call-arg] + chat = PromptLayerChatOpenAI( max_tokens=10, streaming=True, temperature=0, - callback_manager=callback_manager, + callbacks=[callback_handler], verbose=True, ) message = HumanMessage(content="Hello") @@ -75,7 +73,7 @@ def test_promptlayer_chat_openai_streaming() -> None: def test_promptlayer_chat_openai_invalid_streaming_params() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" with pytest.raises(ValueError): - PromptLayerChatOpenAI( # type: ignore[call-arg] + PromptLayerChatOpenAI( max_tokens=10, streaming=True, temperature=0, @@ -85,7 +83,7 @@ def test_promptlayer_chat_openai_invalid_streaming_params() -> None: async def test_async_promptlayer_chat_openai() -> None: """Test async generation.""" - chat = 
PromptLayerChatOpenAI(max_tokens=10, n=2) # type: ignore[call-arg] + chat = PromptLayerChatOpenAI(max_tokens=10, n=2) message = HumanMessage(content="Hello") response = await chat.agenerate([[message], [message]]) assert isinstance(response, LLMResult) @@ -101,12 +99,11 @@ async def test_async_promptlayer_chat_openai() -> None: async def test_async_promptlayer_chat_openai_streaming() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) - chat = PromptLayerChatOpenAI( # type: ignore[call-arg] + chat = PromptLayerChatOpenAI( max_tokens=10, streaming=True, temperature=0, - callback_manager=callback_manager, + callbacks=[callback_handler], verbose=True, ) message = HumanMessage(content="Hello") diff --git a/libs/community/tests/integration_tests/chat_models/test_sambanova.py b/libs/community/tests/integration_tests/chat_models/test_sambanova.py deleted file mode 100644 index 2683b5dc3..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_sambanova.py +++ /dev/null @@ -1,22 +0,0 @@ -from langchain_core.messages import AIMessage, HumanMessage - -from langchain_community.chat_models.sambanova import ( - ChatSambaNovaCloud, - ChatSambaStudio, -) - - -def test_chat_sambanova_cloud() -> None: - chat = ChatSambaNovaCloud() - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -def test_chat_sambastudio() -> None: - chat = ChatSambaStudio() - message = HumanMessage(content="Hello") - response = chat.invoke([message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) diff --git a/libs/community/tests/integration_tests/chat_models/test_vertexai.py b/libs/community/tests/integration_tests/chat_models/test_vertexai.py deleted file mode 100644 index 585476de5..000000000 --- a/libs/community/tests/integration_tests/chat_models/test_vertexai.py +++ /dev/null @@ -1,294 +0,0 @@ -"""Test Vertex AI API wrapper. -In order to run this test, you need to install VertexAI SDK (that is is the private -preview) and be whitelisted to list the models themselves: -In order to run this test, you need to install VertexAI SDK -pip install google-cloud-aiplatform>=1.35.0 - -Your end-user credentials would be used to make the calls (make sure you've run -`gcloud auth login` first). 
-""" - -from typing import Optional -from unittest.mock import MagicMock, Mock, patch - -import pytest -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - HumanMessage, - SystemMessage, -) -from langchain_core.outputs import LLMResult - -from langchain_community.chat_models import ChatVertexAI -from langchain_community.chat_models.vertexai import ( - _parse_chat_history, - _parse_examples, -) - -model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"] - - -@pytest.mark.parametrize("model_name", model_names_to_test) -def test_vertexai_instantiation(model_name: str) -> None: - if model_name: - model = ChatVertexAI(model_name=model_name) - else: - model = ChatVertexAI() - assert model._llm_type == "vertexai" - try: - assert model.model_name == model.client._model_id - except AttributeError: - assert model.model_name == model.client._model_name.split("/")[-1] - - -@pytest.mark.scheduled -@pytest.mark.parametrize("model_name", model_names_to_test) -def test_vertexai_single_call(model_name: str) -> None: - if model_name: - model = ChatVertexAI(model_name=model_name) - else: - model = ChatVertexAI() - message = HumanMessage(content="Hello") - response = model.invoke([message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -# mark xfail because Vertex API randomly doesn't respect -# the n/candidate_count parameter -@pytest.mark.xfail -@pytest.mark.scheduled -def test_candidates() -> None: - model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2) - message = HumanMessage(content="Hello") - response = model.generate(messages=[[message]]) - assert isinstance(response, LLMResult) - assert len(response.generations) == 1 - assert len(response.generations[0]) == 2 - - -@pytest.mark.scheduled -@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"]) -async def test_vertexai_agenerate(model_name: str) -> None: - model = ChatVertexAI(temperature=0, model_name=model_name) - message = HumanMessage(content="Hello") - response = await model.agenerate([[message]]) - assert isinstance(response, LLMResult) - assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore[union-attr] - - sync_response = model.generate([[message]]) - assert response.generations[0][0] == sync_response.generations[0][0] - - -@pytest.mark.scheduled -@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"]) -def test_vertexai_stream(model_name: str) -> None: - model = ChatVertexAI(temperature=0, model_name=model_name) - message = HumanMessage(content="Hello") - - sync_response = model.stream([message]) - for chunk in sync_response: - assert isinstance(chunk, AIMessageChunk) - - -@pytest.mark.scheduled -def test_vertexai_single_call_with_context() -> None: - model = ChatVertexAI() - raw_context = ( - "My name is Ned. You are my personal assistant. My favorite movies " - "are Lord of the Rings and Hobbit." - ) - question = ( - "Hello, could you recommend a good movie for me to watch this evening, please?" 
- ) - context = SystemMessage(content=raw_context) - message = HumanMessage(content=question) - response = model.invoke([context, message]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -def test_multimodal() -> None: - llm = ChatVertexAI(model_name="gemini-ultra-vision") - gcs_url = ( - "gs://cloud-samples-data/generative-ai/image/320px-Felis_catus-cat_on_snow.jpg" - ) - image_message = { - "type": "image_url", - "image_url": {"url": gcs_url}, - } - text_message = { - "type": "text", - "text": "What is shown in this image?", - } - message = HumanMessage(content=[text_message, image_message]) - output = llm.invoke([message]) - assert isinstance(output.content, str) - - -def test_multimodal_history() -> None: - llm = ChatVertexAI(model_name="gemini-ultra-vision") - gcs_url = ( - "gs://cloud-samples-data/generative-ai/image/320px-Felis_catus-cat_on_snow.jpg" - ) - image_message = { - "type": "image_url", - "image_url": {"url": gcs_url}, - } - text_message = { - "type": "text", - "text": "What is shown in this image?", - } - message1 = HumanMessage(content=[text_message, image_message]) - message2 = AIMessage( - content=( - "This is a picture of a cat in the snow. The cat is a tabby cat, which is " - "a type of cat with a striped coat. The cat is standing in the snow, and " - "its fur is covered in snow." - ) - ) - message3 = HumanMessage(content="What time of day is it?") - response = llm.invoke([message1, message2, message3]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -def test_vertexai_single_call_with_examples() -> None: - model = ChatVertexAI() - raw_context = "My name is Ned. You are my personal assistant." - question = "2+2" - text_question, text_answer = "4+4", "8" - inp = HumanMessage(content=text_question) - output = AIMessage(content=text_answer) - context = SystemMessage(content=raw_context) - message = HumanMessage(content=question) - response = model.invoke([context, message], examples=[inp, output]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -@pytest.mark.scheduled -@pytest.mark.parametrize("model_name", model_names_to_test) -def test_vertexai_single_call_with_history(model_name: str) -> None: - if model_name: - model = ChatVertexAI(model_name=model_name) - else: - model = ChatVertexAI() - text_question1, text_answer1 = "How much is 2+2?", "4" - text_question2 = "How much is 3+3?" - message1 = HumanMessage(content=text_question1) - message2 = AIMessage(content=text_answer1) - message3 = HumanMessage(content=text_question2) - response = model.invoke([message1, message2, message3]) - assert isinstance(response, AIMessage) - assert isinstance(response.content, str) - - -def test_parse_chat_history_correct() -> None: - from vertexai.language_models import ChatMessage - - text_context = ( - "My name is Ned. You are my personal assistant. My " - "favorite movies are Lord of the Rings and Hobbit." - ) - context = SystemMessage(content=text_context) - text_question = ( - "Hello, could you recommend a good movie for me to watch this evening, please?" - ) - question = HumanMessage(content=text_question) - text_answer = ( - "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring " - "(2001): This is the first movie in the Lord of the Rings trilogy." 
- ) - answer = AIMessage(content=text_answer) - history = _parse_chat_history([context, question, answer, question, answer]) - assert history.context == context.content - assert len(history.history) == 4 - assert history.history == [ - ChatMessage(content=text_question, author="user"), - ChatMessage(content=text_answer, author="bot"), - ChatMessage(content=text_question, author="user"), - ChatMessage(content=text_answer, author="bot"), - ] - - -def test_vertexai_single_call_fails_no_message() -> None: - chat = ChatVertexAI() - with pytest.raises(ValueError) as exc_info: - _ = chat.invoke([]) - assert ( - str(exc_info.value) - == "You should provide at least one message to start the chat!" - ) - - -@pytest.mark.parametrize("stop", [None, "stop1"]) -def test_vertexai_args_passed(stop: Optional[str]) -> None: - response_text = "Goodbye" - user_prompt = "Hello" - prompt_params = { - "max_output_tokens": 1, - "temperature": 10000.0, - "top_k": 10, - "top_p": 0.5, - } - - # Mock the library to ensure the args are passed correctly - with patch( - "vertexai.language_models._language_models.ChatModel.start_chat" - ) as start_chat: - mock_response = MagicMock() - mock_response.candidates = [Mock(text=response_text)] - mock_chat = MagicMock() - start_chat.return_value = mock_chat - mock_send_message = MagicMock(return_value=mock_response) - mock_chat.send_message = mock_send_message - - model = ChatVertexAI(**prompt_params) # type: ignore[arg-type] - message = HumanMessage(content=user_prompt) - if stop: - response = model.invoke([message], stop=[stop]) - else: - response = model.invoke([message]) - - assert response.content == response_text - mock_send_message.assert_called_once_with(user_prompt, candidate_count=1) - expected_stop_sequence = [stop] if stop else None - start_chat.assert_called_once_with( - context=None, - message_history=[], - **prompt_params, - stop_sequences=expected_stop_sequence, - ) - - -def test_parse_examples_correct() -> None: - from vertexai.language_models import InputOutputTextPair - - text_question = ( - "Hello, could you recommend a good movie for me to watch this evening, please?" - ) - question = HumanMessage(content=text_question) - text_answer = ( - "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring " - "(2001): This is the first movie in the Lord of the Rings trilogy." - ) - answer = AIMessage(content=text_answer) - examples = _parse_examples([question, answer, question, answer]) - assert len(examples) == 2 - assert examples == [ - InputOutputTextPair(input_text=text_question, output_text=text_answer), - InputOutputTextPair(input_text=text_question, output_text=text_answer), - ] - - -def test_parse_examples_failes_wrong_sequence() -> None: - with pytest.raises(ValueError) as exc_info: - _ = _parse_examples([AIMessage(content="a")]) - print(str(exc_info.value)) # noqa: T201 - assert ( - str(exc_info.value) - == "Expect examples to have an even amount of messages, got 1." 
- ) diff --git a/libs/community/tests/integration_tests/document_loaders/test_oracleds.py b/libs/community/tests/integration_tests/document_loaders/test_oracleds.py deleted file mode 100644 index 498e4b692..000000000 --- a/libs/community/tests/integration_tests/document_loaders/test_oracleds.py +++ /dev/null @@ -1,447 +0,0 @@ -# Authors: -# Sudhir Kumar (sudhirkk) -# -# ----------------------------------------------------------------------------- -# test_oracleds.py -# ----------------------------------------------------------------------------- -import sys - -from langchain_community.document_loaders.oracleai import ( - OracleDocLoader, - OracleTextSplitter, -) -from langchain_community.utilities.oracleai import OracleSummary -from langchain_community.vectorstores.oraclevs import ( - _table_exists, - drop_table_purge, -) - -uname = "hr" -passwd = "hr" -# uname = "LANGCHAINUSER" -# passwd = "langchainuser" -v_dsn = "100.70.107.245:1521/cdb1_pdb1.regress.rdbms.dev.us.oracle.com" - - -### Test loader ##### -def test_loader_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - # oracle connection - connection = oracledb.connect(user=uname, password=passwd, dsn=v_dsn) - cursor = connection.cursor() - - if _table_exists(connection, "LANGCHAIN_DEMO"): - drop_table_purge(connection, "LANGCHAIN_DEMO") - - cursor.execute("CREATE TABLE langchain_demo(id number, text varchar2(25))") - - rows = [ - (1, "First"), - (2, "Second"), - (3, "Third"), - (4, "Fourth"), - (5, "Fifth"), - (6, "Sixth"), - (7, "Seventh"), - ] - - cursor.executemany("insert into LANGCHAIN_DEMO(id, text) values (:1, :2)", rows) - - connection.commit() - - # local file, local directory, database column - loader_params = { - "owner": uname, - "tablename": "LANGCHAIN_DEMO", - "colname": "TEXT", - } - - # instantiate - loader = OracleDocLoader(conn=connection, params=loader_params) - - # load - docs = loader.load() - - # verify - if len(docs) == 0: - sys.exit(1) - - if _table_exists(connection, "LANGCHAIN_DEMO"): - drop_table_purge(connection, "LANGCHAIN_DEMO") - - except Exception: - sys.exit(1) - - try: - # expectation : ORA-00942 - loader_params = { - "owner": uname, - "tablename": "COUNTRIES1", - "colname": "COUNTRY_NAME", - } - - # instantiate - loader = OracleDocLoader(conn=connection, params=loader_params) - - # load - docs = loader.load() - if len(docs) == 0: - pass - - except Exception: - pass - - try: - # expectation : file "SUDHIR" doesn't exist. - loader_params = {"file": "SUDHIR"} - - # instantiate - loader = OracleDocLoader(conn=connection, params=loader_params) - - # load - docs = loader.load() - if len(docs) == 0: - pass - - except Exception: - pass - - try: - # expectation : path "SUDHIR" doesn't exist. 
- loader_params = {"dir": "SUDHIR"} - - # instantiate - loader = OracleDocLoader(conn=connection, params=loader_params) - - # load - docs = loader.load() - if len(docs) == 0: - pass - - except Exception: - pass - - -### Test splitter #### -def test_splitter_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - # oracle connection - connection = oracledb.connect(user=uname, password=passwd, dsn=v_dsn) - doc = """Langchain is a wonderful framework to load, split, chunk - and embed your data!!""" - - # by words , max = 1000 - splitter_params = { - "by": "words", - "max": "1000", - "overlap": "200", - "split": "custom", - "custom_list": [","], - "extended": "true", - "normalize": "all", - } - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - sys.exit(1) - - # by chars , max = 4000 - splitter_params = { - "by": "chars", - "max": "4000", - "overlap": "800", - "split": "NEWLINE", - "normalize": "all", - } - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - sys.exit(1) - - # by words , max = 10 - splitter_params = { - "by": "words", - "max": "10", - "overlap": "2", - "split": "SENTENCE", - } - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - sys.exit(1) - - # by chars , max = 50 - splitter_params = { - "by": "chars", - "max": "50", - "overlap": "10", - "split": "SPACE", - "normalize": "all", - } - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - sys.exit(1) - - except Exception: - sys.exit(1) - - try: - # ORA-20003: invalid value xyz for BY parameter - splitter_params = {"by": "xyz"} - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - pass - - except Exception: - pass - - try: - # Expectation: ORA-30584: invalid text chunking MAXIMUM - '10' - splitter_params = { - "by": "chars", - "max": "10", - "overlap": "2", - "split": "SPACE", - "normalize": "all", - } - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - pass - - except Exception: - pass - - try: - # Expectation: ORA-30584: invalid text chunking MAXIMUM - '5' - splitter_params = { - "by": "words", - "max": "5", - "overlap": "2", - "split": "SPACE", - "normalize": "all", - } - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - pass - - except Exception: - pass - - try: - # Expectation: ORA-30586: invalid text chunking SPLIT BY - SENTENCE - splitter_params = { - "by": "words", - "max": "50", - "overlap": "2", - "split": "SENTENCE", - "normalize": "all", - } - - # instantiate - splitter = OracleTextSplitter(conn=connection, params=splitter_params) - - # generate chunks - chunks = splitter.split_text(doc) - - # verify - if len(chunks) == 0: - pass - - except Exception: - pass - - -#### Test 
summary #### -def test_summary_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - # oracle connection - connection = oracledb.connect(user=uname, password=passwd, dsn=v_dsn) - - # provider : Database, glevel : Paragraph - summary_params = { - "provider": "database", - "glevel": "paragraph", - "numParagraphs": 2, - "language": "english", - } - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - - doc = """It was 7 minutes after midnight. The dog was lying on the grass in - of the lawn in front of Mrs Shears house. Its eyes were closed. It - was running on its side, the way dogs run when they think they are - cat in a dream. But the dog was not running or asleep. The dog was dead. - was a garden fork sticking out of the dog. The points of the fork must - gone all the way through the dog and into the ground because the fork - not fallen over. I decided that the dog was probably killed with the - because I could not see any other wounds in the dog and I do not think - would stick a garden fork into a dog after it had died for some other - like cancer for example, or a road accident. But I could not be certain""" - - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - sys.exit(1) - - # provider : Database, glevel : Sentence - summary_params = {"provider": "database", "glevel": "Sentence"} - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - sys.exit(1) - - # provider : Database, glevel : P - summary_params = {"provider": "database", "glevel": "P"} - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - sys.exit(1) - - # provider : Database, glevel : S - summary_params = { - "provider": "database", - "glevel": "S", - "numParagraphs": 16, - "language": "english", - } - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - sys.exit(1) - - # provider : Database, glevel : S, doc = ' ' - summary_params = {"provider": "database", "glevel": "S", "numParagraphs": 2} - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - - doc = " " - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - sys.exit(1) - - except Exception: - sys.exit(1) - - try: - # Expectation : DRG-11002: missing value for PROVIDER - summary_params = {"provider": "database1", "glevel": "S"} - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - pass - - except Exception: - pass - - try: - # Expectation : DRG-11425: gist level SUDHIR is invalid, - # DRG-11427: valid gist level values are S, P - summary_params = {"provider": "database", "glevel": "SUDHIR"} - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - pass - - except Exception: - pass - - try: - # Expectation : DRG-11441: gist numParagraphs -2 is invalid - summary_params = {"provider": "database", "glevel": "S", "numParagraphs": -2} - - # summary - summary = OracleSummary(conn=connection, params=summary_params) - summaries = summary.get_summary(doc) - - # verify - if len(summaries) == 0: - pass - - except Exception: - pass diff 
--git a/libs/community/tests/integration_tests/memory/test_memory_astradb.py b/libs/community/tests/integration_tests/memory/test_memory_astradb.py deleted file mode 100644 index a31edfee2..000000000 --- a/libs/community/tests/integration_tests/memory/test_memory_astradb.py +++ /dev/null @@ -1,202 +0,0 @@ -import os -from typing import AsyncIterable, Iterable - -import pytest -from langchain_classic.memory import ConversationBufferMemory -from langchain_core.messages import AIMessage, HumanMessage - -from langchain_community.chat_message_histories.astradb import ( - AstraDBChatMessageHistory, -) -from langchain_community.utilities.astradb import SetupMode - - -def _has_env_vars() -> bool: - return all( - [ - "ASTRA_DB_APPLICATION_TOKEN" in os.environ, - "ASTRA_DB_API_ENDPOINT" in os.environ, - ] - ) - - -@pytest.fixture(scope="function") -def history1() -> Iterable[AstraDBChatMessageHistory]: - history1 = AstraDBChatMessageHistory( - session_id="session-test-1", - collection_name="langchain_cmh_test", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - ) - yield history1 - history1.collection.astra_db.delete_collection("langchain_cmh_test") - - -@pytest.fixture(scope="function") -def history2() -> Iterable[AstraDBChatMessageHistory]: - history2 = AstraDBChatMessageHistory( - session_id="session-test-2", - collection_name="langchain_cmh_test", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - ) - yield history2 - history2.collection.astra_db.delete_collection("langchain_cmh_test") - - -@pytest.fixture -async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]: - history1 = AstraDBChatMessageHistory( - session_id="async-session-test-1", - collection_name="langchain_cmh_test", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - setup_mode=SetupMode.ASYNC, - ) - yield history1 - await history1.async_collection.astra_db.delete_collection("langchain_cmh_test") - - -@pytest.fixture(scope="function") -async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]: - history2 = AstraDBChatMessageHistory( - session_id="async-session-test-2", - collection_name="langchain_cmh_test", - token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], - api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], - namespace=os.environ.get("ASTRA_DB_KEYSPACE"), - setup_mode=SetupMode.ASYNC, - ) - yield history2 - await history2.async_collection.astra_db.delete_collection("langchain_cmh_test") - - -@pytest.mark.requires("astrapy") -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. 
vars") -def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None: - """Test the memory with a message store.""" - memory = ConversationBufferMemory( - memory_key="baz", - chat_memory=history1, - return_messages=True, - ) - - assert memory.chat_memory.messages == [] - - # add some messages - memory.chat_memory.add_messages( - [ - AIMessage(content="This is me, the AI"), - HumanMessage(content="This is me, the human"), - ] - ) - - messages = memory.chat_memory.messages - expected = [ - AIMessage(content="This is me, the AI"), - HumanMessage(content="This is me, the human"), - ] - assert messages == expected - - # clear the store - memory.chat_memory.clear() - - assert memory.chat_memory.messages == [] - - -@pytest.mark.requires("astrapy") -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") -async def test_memory_with_message_store_async( - async_history1: AstraDBChatMessageHistory, -) -> None: - """Test the memory with a message store.""" - memory = ConversationBufferMemory( - memory_key="baz", - chat_memory=async_history1, - return_messages=True, - ) - - assert await memory.chat_memory.aget_messages() == [] - - # add some messages - await memory.chat_memory.aadd_messages( - [ - AIMessage(content="This is me, the AI"), - HumanMessage(content="This is me, the human"), - ] - ) - - messages = await memory.chat_memory.aget_messages() - expected = [ - AIMessage(content="This is me, the AI"), - HumanMessage(content="This is me, the human"), - ] - assert messages == expected - - # clear the store - await memory.chat_memory.aclear() - - assert await memory.chat_memory.aget_messages() == [] - - -@pytest.mark.requires("astrapy") -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") -def test_memory_separate_session_ids( - history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory -) -> None: - """Test that separate session IDs do not share entries.""" - memory1 = ConversationBufferMemory( - memory_key="mk1", - chat_memory=history1, - return_messages=True, - ) - memory2 = ConversationBufferMemory( - memory_key="mk2", - chat_memory=history2, - return_messages=True, - ) - - memory1.chat_memory.add_messages([AIMessage(content="Just saying.")]) - - assert memory2.chat_memory.messages == [] - - memory2.chat_memory.clear() - - assert memory1.chat_memory.messages != [] - - memory1.chat_memory.clear() - - assert memory1.chat_memory.messages == [] - - -@pytest.mark.requires("astrapy") -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. 
vars") -async def test_memory_separate_session_ids_async( - async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory -) -> None: - """Test that separate session IDs do not share entries.""" - memory1 = ConversationBufferMemory( - memory_key="mk1", - chat_memory=async_history1, - return_messages=True, - ) - memory2 = ConversationBufferMemory( - memory_key="mk2", - chat_memory=async_history2, - return_messages=True, - ) - - await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")]) - - assert await memory2.chat_memory.aget_messages() == [] - - await memory2.chat_memory.aclear() - - assert await memory1.chat_memory.aget_messages() != [] - - await memory1.chat_memory.aclear() - - assert await memory1.chat_memory.aget_messages() == [] diff --git a/libs/community/tests/integration_tests/memory/test_mongodb.py b/libs/community/tests/integration_tests/memory/test_mongodb.py deleted file mode 100644 index 3c13fb9d5..000000000 --- a/libs/community/tests/integration_tests/memory/test_mongodb.py +++ /dev/null @@ -1,37 +0,0 @@ -import json -import os - -from langchain_classic.memory import ConversationBufferMemory -from langchain_core.messages import message_to_dict - -from langchain_community.chat_message_histories import MongoDBChatMessageHistory - -# Replace these with your mongodb connection string -connection_string = os.environ.get("MONGODB_CONNECTION_STRING", "") - - -def test_memory_with_message_store() -> None: - """Test the memory with a message store.""" - # setup MongoDB as a message store - message_history = MongoDBChatMessageHistory( - connection_string=connection_string, session_id="test-session" - ) - memory = ConversationBufferMemory( - memory_key="baz", chat_memory=message_history, return_messages=True - ) - - # add some messages - memory.chat_memory.add_ai_message("This is me, the AI") - memory.chat_memory.add_user_message("This is me, the human") - - # get the message history from the memory store and turn it into a json - messages = memory.chat_memory.messages - messages_json = json.dumps([message_to_dict(msg) for msg in messages]) - - assert "This is me, the AI" in messages_json - assert "This is me, the human" in messages_json - - # remove the record from Azure Cosmos DB, so the next test run won't pick it up - memory.chat_memory.clear() - - assert memory.chat_memory.messages == [] diff --git a/libs/community/tests/integration_tests/memory/test_neo4j.py b/libs/community/tests/integration_tests/memory/test_neo4j.py deleted file mode 100644 index cc0a8a119..000000000 --- a/libs/community/tests/integration_tests/memory/test_neo4j.py +++ /dev/null @@ -1,31 +0,0 @@ -import json - -from langchain_classic.memory import ConversationBufferMemory -from langchain_core.messages import message_to_dict - -from langchain_community.chat_message_histories import Neo4jChatMessageHistory - - -def test_memory_with_message_store() -> None: - """Test the memory with a message store.""" - # setup MongoDB as a message store - message_history = Neo4jChatMessageHistory(session_id="test-session") - memory = ConversationBufferMemory( - memory_key="baz", chat_memory=message_history, return_messages=True - ) - - # add some messages - memory.chat_memory.add_ai_message("This is me, the AI") - memory.chat_memory.add_user_message("This is me, the human") - - # get the message history from the memory store and turn it into a json - messages = memory.chat_memory.messages - messages_json = json.dumps([message_to_dict(msg) for msg in messages]) - - assert "This is me, the 
AI" in messages_json - assert "This is me, the human" in messages_json - - # remove the record from Azure Cosmos DB, so the next test run won't pick it up - memory.chat_memory.clear() - - assert memory.chat_memory.messages == [] diff --git a/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py index 2f8bec891..dda7714f4 100644 --- a/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py @@ -3,7 +3,7 @@ from langchain_classic.retrievers.document_compressors import LLMChainExtractor from langchain_core.documents import Document -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI def test_llm_chain_extractor() -> None: diff --git a/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py index 6ef6988a5..4a8145e4d 100644 --- a/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py +++ b/libs/community/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py @@ -3,7 +3,7 @@ from langchain_classic.retrievers.document_compressors import LLMChainFilter from langchain_core.documents import Document -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI def test_llm_chain_filter() -> None: diff --git a/libs/community/tests/integration_tests/retrievers/test_qdrant_sparse_vector_retriever.py b/libs/community/tests/integration_tests/retrievers/test_qdrant_sparse_vector_retriever.py deleted file mode 100644 index 7afb7a771..000000000 --- a/libs/community/tests/integration_tests/retrievers/test_qdrant_sparse_vector_retriever.py +++ /dev/null @@ -1,170 +0,0 @@ -import random -import uuid -from typing import List, Tuple - -import pytest -from langchain_core.documents import Document - -from langchain_community.retrievers import QdrantSparseVectorRetriever -from langchain_community.vectorstores.qdrant import QdrantException - - -def consistent_fake_sparse_encoder( - query: str, size: int = 100, density: float = 0.7 -) -> Tuple[List[int], List[float]]: - """ - Generates a consistent fake sparse vector. - - Parameters: - - query (str): The query string to make the function deterministic. - - size (int): The size of the vector to generate. - - density (float): The density of the vector to generate. - - Returns: - - indices (list): List of indices where the non-zero elements are located. - - values (list): List of corresponding float values at the non-zero indices. 
- """ - # Ensure density is within the valid range [0, 1] - density = max(0.0, min(1.0, density)) - - # Use a deterministic seed based on the query - seed = hash(query) - random.seed(seed) - - # Calculate the number of non-zero elements based on density - num_non_zero_elements = int(size * density) - - # Generate random indices without replacement - indices = sorted(random.sample(range(size), num_non_zero_elements)) - - # Generate random float values for the non-zero elements - values = [random.uniform(0.0, 1.0) for _ in range(num_non_zero_elements)] - - return indices, values - - -@pytest.fixture -def retriever() -> QdrantSparseVectorRetriever: - from qdrant_client import QdrantClient, models - - client = QdrantClient(location=":memory:") - - collection_name = uuid.uuid4().hex - vector_name = uuid.uuid4().hex - - client.recreate_collection( - collection_name, - vectors_config={}, - sparse_vectors_config={ - vector_name: models.SparseVectorParams( - index=models.SparseIndexParams( - on_disk=False, - ) - ) - }, - ) - - return QdrantSparseVectorRetriever( - client=client, - collection_name=collection_name, - sparse_vector_name=vector_name, - sparse_encoder=consistent_fake_sparse_encoder, - ) - - -def test_invalid_collection_name(retriever: QdrantSparseVectorRetriever) -> None: - with pytest.raises(QdrantException) as e: - QdrantSparseVectorRetriever( - client=retriever.client, - collection_name="invalid collection", - sparse_vector_name=retriever.sparse_vector_name, - sparse_encoder=consistent_fake_sparse_encoder, - ) - assert "does not exist" in str(e.value) - - -def test_invalid_sparse_vector_name(retriever: QdrantSparseVectorRetriever) -> None: - with pytest.raises(QdrantException) as e: - QdrantSparseVectorRetriever( - client=retriever.client, - collection_name=retriever.collection_name, - sparse_vector_name="invalid sparse vector", - sparse_encoder=consistent_fake_sparse_encoder, - ) - - assert "does not contain sparse vector" in str(e.value) - - -def test_add_documents(retriever: QdrantSparseVectorRetriever) -> None: - documents = [ - Document(page_content="hello world", metadata={"a": 1}), - Document(page_content="foo bar", metadata={"b": 2}), - Document(page_content="baz qux", metadata={"c": 3}), - ] - - ids = retriever.add_documents(documents) - - assert retriever.client.count(retriever.collection_name, exact=True).count == 3 - - documents = [ - Document(page_content="hello world"), - Document(page_content="foo bar"), - Document(page_content="baz qux"), - ] - - ids = retriever.add_documents(documents) - - assert len(ids) == 3 - - assert retriever.client.count(retriever.collection_name, exact=True).count == 6 - - -def test_add_texts(retriever: QdrantSparseVectorRetriever) -> None: - retriever.add_texts( - ["hello world", "foo bar", "baz qux"], [{"a": 1}, {"b": 2}, {"c": 3}] - ) - - assert retriever.client.count(retriever.collection_name, exact=True).count == 3 - - retriever.add_texts(["hello world", "foo bar", "baz qux"]) - - assert retriever.client.count(retriever.collection_name, exact=True).count == 6 - - -def test_invoke(retriever: QdrantSparseVectorRetriever) -> None: - retriever.add_texts(["Hai there!", "Hello world!", "Foo bar baz!"]) - - expected = [Document(page_content="Hai there!")] - - retriever.k = 1 - results = retriever.invoke("Hai there!") - - assert len(results) == retriever.k - assert results == expected - assert retriever.invoke("Hai there!") == expected - - -def test_invoke_with_filter( - retriever: QdrantSparseVectorRetriever, -) -> None: - from qdrant_client 
import models - - retriever.add_texts( - ["Hai there!", "Hello world!", "Foo bar baz!"], - [ - {"value": 1}, - {"value": 2}, - {"value": 3}, - ], - ) - - retriever.filter = models.Filter( - must=[ - models.FieldCondition( - key="metadata.value", match=models.MatchValue(value=2) - ) - ] - ) - results = retriever.invoke("Some query") - - assert results[0] == Document(page_content="Hello world!", metadata={"value": 2}) diff --git a/libs/community/tests/integration_tests/smith/evaluation/test_runner_utils.py b/libs/community/tests/integration_tests/smith/evaluation/test_runner_utils.py index bc03e0a45..e7b28b4b2 100644 --- a/libs/community/tests/integration_tests/smith/evaluation/test_runner_utils.py +++ b/libs/community/tests/integration_tests/smith/evaluation/test_runner_utils.py @@ -13,7 +13,7 @@ from langsmith.evaluation import run_evaluator from langsmith.schemas import DataType, Example, Run -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.llms.openai import OpenAI diff --git a/libs/community/tests/integration_tests/utilities/test_googlesearch_api.py b/libs/community/tests/integration_tests/utilities/test_googlesearch_api.py deleted file mode 100644 index b32d05e76..000000000 --- a/libs/community/tests/integration_tests/utilities/test_googlesearch_api.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Integration test for Google Search API Wrapper.""" - -from langchain_community.utilities.google_search import GoogleSearchAPIWrapper - - -def test_call() -> None: - """Test that call gives the correct answer.""" - search = GoogleSearchAPIWrapper() - output = search.run("What was Obama's first name?") - assert "Barack Hussein Obama II" in output - - -def test_no_result_call() -> None: - """Test that call gives no result.""" - search = GoogleSearchAPIWrapper() - output = search.run( - "NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL" - ) - print(type(output)) # noqa: T201 - assert "No good Google Search Result was found" == output - - -def test_result_with_params_call() -> None: - """Test that call gives the correct answer with extra params.""" - search = GoogleSearchAPIWrapper() - output = search.results( - query="What was Obama's first name?", - num_results=5, - search_params={"cr": "us", "safe": "active"}, - ) - assert len(output) diff --git a/libs/community/tests/integration_tests/vectorstores/conftest.py b/libs/community/tests/integration_tests/vectorstores/conftest.py index a2fc90531..30b898a00 100644 --- a/libs/community/tests/integration_tests/vectorstores/conftest.py +++ b/libs/community/tests/integration_tests/vectorstores/conftest.py @@ -1,16 +1,8 @@ -import os from typing import Union import pytest from vcr.request import Request -# Those environment variables turn on Deep Lake pytest mode. -# It significantly makes tests run much faster. 
-# Need to run before `import deeplake` -os.environ["BUGGER_OFF"] = "true" -os.environ["DEEPLAKE_DOWNLOAD_PATH"] = "./testing/local_storage" -os.environ["DEEPLAKE_PYTEST_ENABLED"] = "true" - # This fixture returns a dictionary containing filter_headers options # for replacing certain headers with dummy values during cassette playback diff --git a/libs/community/tests/integration_tests/vectorstores/docker-compose/weaviate.yml b/libs/community/tests/integration_tests/vectorstores/docker-compose/weaviate.yml deleted file mode 100644 index e270c71c1..000000000 --- a/libs/community/tests/integration_tests/vectorstores/docker-compose/weaviate.yml +++ /dev/null @@ -1,23 +0,0 @@ -version: '3.4' - -services: - weaviate: - command: - - --host - - 0.0.0.0 - - --port - - '8080' - - --scheme - - http - image: semitechnologies/weaviate:1.18.2 - ports: - - 8080:8080 - restart: on-failure:0 - environment: - QUERY_DEFAULTS_LIMIT: 25 - AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' - PERSISTENCE_DATA_PATH: '/var/lib/weaviate' - DEFAULT_VECTORIZER_MODULE: 'text2vec-openai' - ENABLE_MODULES: 'text2vec-openai' - OPENAI_APIKEY: '${OPENAI_API_KEY}' - CLUSTER_HOSTNAME: 'node1' diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/__init__.py b/libs/community/tests/integration_tests/vectorstores/qdrant/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/__init__.py b/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py b/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py deleted file mode 100644 index 2728f4f68..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py +++ /dev/null @@ -1,13 +0,0 @@ -import logging -from typing import List - -from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running - -logger = logging.getLogger(__name__) - - -def qdrant_locations() -> List[str]: - if qdrant_is_not_running(): - logger.warning("Running Qdrant async tests in memory mode only.") - return [":memory:"] - return ["http://localhost:6333", ":memory:"] diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py b/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py deleted file mode 100644 index 83ff601bc..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py +++ /dev/null @@ -1,124 +0,0 @@ -import uuid -from typing import Optional - -import pytest - -from langchain_community.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( - qdrant_locations, -) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_aadd_texts_returns_all_ids( - batch_size: int, qdrant_location: str -) -> None: - """Test end to end Qdrant.aadd_texts returns unique ids.""" - docsearch: Qdrant = Qdrant.from_texts( - ["foobar"], - ConsistentFakeEmbeddings(), - batch_size=batch_size, - location=qdrant_location, - ) - - ids = await docsearch.aadd_texts(["foo", "bar", "baz"]) - assert 3 == len(ids) - assert 3 == len(set(ids)) - 
- -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_aadd_texts_stores_duplicated_texts( - vector_name: Optional[str], qdrant_location: str -) -> None: - """Test end to end Qdrant.aadd_texts stores duplicated texts separately.""" - from qdrant_client import QdrantClient - from qdrant_client.http import models as rest - - client = QdrantClient(location=qdrant_location) - collection_name = uuid.uuid4().hex - vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE) - if vector_name is not None: - vectors_config = {vector_name: vectors_config} - client.recreate_collection(collection_name, vectors_config=vectors_config) - - vec_store = Qdrant( - client, - collection_name, - embeddings=ConsistentFakeEmbeddings(), - vector_name=vector_name, - ) - ids = await vec_store.aadd_texts(["abc", "abc"], [{"a": 1}, {"a": 2}]) - - assert 2 == len(set(ids)) - assert 2 == client.count(collection_name).count - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_aadd_texts_stores_ids( - batch_size: int, qdrant_location: str -) -> None: - """Test end to end Qdrant.aadd_texts stores provided ids.""" - from qdrant_client import QdrantClient - from qdrant_client.http import models as rest - - ids = [ - "fa38d572-4c31-4579-aedc-1960d79df6df", - "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", - ] - - client = QdrantClient(location=qdrant_location) - collection_name = uuid.uuid4().hex - client.recreate_collection( - collection_name, - vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), - ) - - vec_store = Qdrant(client, collection_name, ConsistentFakeEmbeddings()) - returned_ids = await vec_store.aadd_texts( - ["abc", "def"], ids=ids, batch_size=batch_size - ) - - assert all(first == second for first, second in zip(ids, returned_ids)) - assert 2 == client.count(collection_name).count - stored_ids = [point.id for point in client.scroll(collection_name)[0]] - assert set(ids) == set(stored_ids) - - -@pytest.mark.parametrize("vector_name", ["custom-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors( - vector_name: str, qdrant_location: str -) -> None: - """Test end to end Qdrant.aadd_texts stores named vectors if name is provided.""" - from qdrant_client import QdrantClient - from qdrant_client.http import models as rest - - collection_name = uuid.uuid4().hex - - client = QdrantClient(location=qdrant_location) - client.recreate_collection( - collection_name, - vectors_config={ - vector_name: rest.VectorParams(size=10, distance=rest.Distance.COSINE) - }, - ) - - vec_store = Qdrant( - client, - collection_name, - ConsistentFakeEmbeddings(), - vector_name=vector_name, - ) - await vec_store.aadd_texts(["lorem", "ipsum", "dolor", "sit", "amet"]) - - assert 5 == client.count(collection_name).count - assert all( - vector_name in point.vector - for point in client.scroll(collection_name, with_vectors=True)[0] - ) diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py b/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py deleted file mode 100644 index bb787a7a2..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py +++ /dev/null @@ -1,253 +0,0 @@ -import uuid -from typing import Optional - 
-import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import Qdrant -from langchain_community.vectorstores.qdrant import QdrantException -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( - qdrant_locations, -) -from tests.integration_tests.vectorstores.qdrant.common import ( - assert_documents_equals, - qdrant_is_not_running, -) - - -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -> None: - """Test end to end Qdrant.afrom_texts stores duplicated texts separately.""" - collection_name = uuid.uuid4().hex - - vec_store = await Qdrant.afrom_texts( - ["abc", "abc"], - ConsistentFakeEmbeddings(), - collection_name=collection_name, - location=qdrant_location, - ) - - client = vec_store.client - assert 2 == client.count(collection_name).count - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_from_texts_stores_ids( - batch_size: int, vector_name: Optional[str], qdrant_location: str -) -> None: - """Test end to end Qdrant.afrom_texts stores provided ids.""" - collection_name = uuid.uuid4().hex - ids = [ - "fa38d572-4c31-4579-aedc-1960d79df6df", - "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", - ] - vec_store = await Qdrant.afrom_texts( - ["abc", "def"], - ConsistentFakeEmbeddings(), - ids=ids, - collection_name=collection_name, - batch_size=batch_size, - vector_name=vector_name, - location=qdrant_location, - ) - - client = vec_store.client - assert 2 == client.count(collection_name).count - stored_ids = [point.id for point in client.scroll(collection_name)[0]] - assert set(ids) == set(stored_ids) - - -@pytest.mark.parametrize("vector_name", ["custom-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_from_texts_stores_embeddings_as_named_vectors( - vector_name: str, - qdrant_location: str, -) -> None: - """Test end to end Qdrant.afrom_texts stores named vectors if name is provided.""" - collection_name = uuid.uuid4().hex - - vec_store = await Qdrant.afrom_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(), - collection_name=collection_name, - vector_name=vector_name, - location=qdrant_location, - ) - - client = vec_store.client - assert 5 == client.count(collection_name).count - assert all( - vector_name in point.vector - for point in client.scroll(collection_name, with_vectors=True)[0] - ) - - -@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) -@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") -async def test_qdrant_from_texts_reuses_same_collection( - vector_name: Optional[str], -) -> None: - """Test if Qdrant.afrom_texts reuses the same collection""" - collection_name = uuid.uuid4().hex - embeddings = ConsistentFakeEmbeddings() - - await Qdrant.afrom_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - embeddings, - collection_name=collection_name, - vector_name=vector_name, - ) - - vec_store = await Qdrant.afrom_texts( - ["foo", "bar"], - embeddings, - collection_name=collection_name, - vector_name=vector_name, - ) - - client = vec_store.client - assert 7 == client.count(collection_name).count - - -@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) 
-@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") -async def test_qdrant_from_texts_raises_error_on_different_dimensionality( - vector_name: Optional[str], -) -> None: - """Test if Qdrant.afrom_texts raises an exception if dimensionality does not - match""" - collection_name = uuid.uuid4().hex - - await Qdrant.afrom_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), - collection_name=collection_name, - vector_name=vector_name, - ) - - with pytest.raises(QdrantException): - await Qdrant.afrom_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), - collection_name=collection_name, - vector_name=vector_name, - ) - - -@pytest.mark.parametrize( - ["first_vector_name", "second_vector_name"], - [ - (None, "custom-vector"), - ("custom-vector", None), - ("my-first-vector", "my-second_vector"), - ], -) -@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") -async def test_qdrant_from_texts_raises_error_on_different_vector_name( - first_vector_name: Optional[str], - second_vector_name: Optional[str], -) -> None: - """Test if Qdrant.afrom_texts raises an exception if vector name does not match""" - collection_name = uuid.uuid4().hex - - await Qdrant.afrom_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), - collection_name=collection_name, - vector_name=first_vector_name, - ) - - with pytest.raises(QdrantException): - await Qdrant.afrom_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), - collection_name=collection_name, - vector_name=second_vector_name, - ) - - -@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") -async def test_qdrant_from_texts_raises_error_on_different_distance() -> None: - """Test if Qdrant.afrom_texts raises an exception if distance does not match""" - collection_name = uuid.uuid4().hex - - await Qdrant.afrom_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), - collection_name=collection_name, - distance_func="Cosine", - ) - - with pytest.raises(QdrantException): - await Qdrant.afrom_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), - collection_name=collection_name, - distance_func="Euclid", - ) - - -@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) -@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") -async def test_qdrant_from_texts_recreates_collection_on_force_recreate( - vector_name: Optional[str], -) -> None: - """Test if Qdrant.afrom_texts recreates the collection even if config mismatches""" - from qdrant_client import QdrantClient - - collection_name = uuid.uuid4().hex - - await Qdrant.afrom_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), - collection_name=collection_name, - vector_name=vector_name, - ) - - await Qdrant.afrom_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), - collection_name=collection_name, - vector_name=vector_name, - force_recreate=True, - ) - - client = QdrantClient() - assert 2 == client.count(collection_name).count - vector_params = client.get_collection(collection_name).config.params.vectors - if vector_name is not None: - vector_params = vector_params[vector_name] - assert 5 == vector_params.size - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) 
-@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_from_texts_stores_metadatas( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = await Qdrant.afrom_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - location=qdrant_location, - ) - output = await docsearch.asimilarity_search("foo", k=1) - assert_documents_equals( - output, [Document(page_content="foo", metadata={"page": 0})] - ) diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py b/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py deleted file mode 100644 index 472a66bb0..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional - -import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( - qdrant_locations, -) -from tests.integration_tests.vectorstores.qdrant.common import assert_documents_equals - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_max_marginal_relevance_search( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and MRR search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - location=qdrant_location, - distance_func="EUCLID", # Euclid distance used to avoid normalization - ) - output = await docsearch.amax_marginal_relevance_search( - "foo", k=2, fetch_k=3, lambda_mult=0.0 - ) - assert_documents_equals( - output, - [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="baz", metadata={"page": 2}), - ], - ) diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py b/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py deleted file mode 100644 index 5ae98ad35..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py +++ /dev/null @@ -1,307 +0,0 @@ -from typing import Optional - -import numpy as np -import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import Qdrant -from 
tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( - qdrant_locations, -) -from tests.integration_tests.vectorstores.qdrant.common import assert_documents_equals - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - location=qdrant_location, - ) - output = await docsearch.asimilarity_search("foo", k=1) - assert_documents_equals(output, [Document(page_content="foo")]) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search_by_vector( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - location=qdrant_location, - ) - embeddings = ConsistentFakeEmbeddings().embed_query("foo") - output = await docsearch.asimilarity_search_by_vector(embeddings, k=1) - assert_documents_equals(output, [Document(page_content="foo")]) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search_with_score_by_vector( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - location=qdrant_location, - ) - embeddings = ConsistentFakeEmbeddings().embed_query("foo") - output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1) - assert len(output) == 1 - document, score = output[0] - assert_documents_equals([document], [Document(page_content="foo")]) - assert score >= 0 - - -@pytest.mark.parametrize("batch_size", [1, 64]) 
-@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search_filters( - batch_size: int, vector_name: Optional[str], qdrant_location: str -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - batch_size=batch_size, - vector_name=vector_name, - location=qdrant_location, - ) - - output = await docsearch.asimilarity_search( - "foo", k=1, filter={"page": 1, "metadata": {"page": 2, "pages": [3]}} - ) - assert_documents_equals( - output, - [ - Document( - page_content="bar", - metadata={"page": 1, "metadata": {"page": 2, "pages": [3, -1]}}, - ) - ], - ) - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search_with_relevance_score_no_threshold( - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - vector_name=vector_name, - location=qdrant_location, - ) - output = await docsearch.asimilarity_search_with_relevance_scores( - "foo", k=3, score_threshold=None - ) - assert len(output) == 3 - for i in range(len(output)): - assert round(output[i][1], 2) >= 0 - assert round(output[i][1], 2) <= 1 - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search_with_relevance_score_with_threshold( - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - vector_name=vector_name, - location=qdrant_location, - ) - - score_threshold = 0.98 - kwargs = {"score_threshold": score_threshold} - output = await docsearch.asimilarity_search_with_relevance_scores( - "foo", k=3, **kwargs - ) - assert len(output) == 1 - assert all([score >= score_threshold for _, score in output]) - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_similarity_search_with_relevance_score_with_threshold_and_filter( - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - vector_name=vector_name, - location=qdrant_location, - ) - score_threshold = 0.99 # for almost exact match - # test negative filter condition - negative_filter = {"page": 1, "metadata": {"page": 2, "pages": [3]}} - kwargs = {"filter": negative_filter, "score_threshold": score_threshold} - output = 
docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) - assert len(output) == 0 - # test positive filter condition - positive_filter = {"page": 0, "metadata": {"page": 1, "pages": [2]}} - kwargs = {"filter": positive_filter, "score_threshold": score_threshold} - output = await docsearch.asimilarity_search_with_relevance_scores( - "foo", k=3, **kwargs - ) - assert len(output) == 1 - assert all([score >= score_threshold for _, score in output]) - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search_filters_with_qdrant_filters( - vector_name: Optional[str], - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - from qdrant_client.http import models as rest - - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "details": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - vector_name=vector_name, - location=qdrant_location, - ) - - qdrant_filter = rest.Filter( - must=[ - rest.FieldCondition( - key="metadata.page", - match=rest.MatchValue(value=1), - ), - rest.FieldCondition( - key="metadata.details.page", - match=rest.MatchValue(value=2), - ), - rest.FieldCondition( - key="metadata.details.pages", - match=rest.MatchAny(any=[3]), - ), - ] - ) - output = await docsearch.asimilarity_search("foo", k=1, filter=qdrant_filter) - assert_documents_equals( - output, - [ - Document( - page_content="bar", - metadata={"page": 1, "details": {"page": 2, "pages": [3, -1]}}, - ) - ], - ) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.parametrize("qdrant_location", qdrant_locations()) -async def test_qdrant_similarity_search_with_relevance_scores( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: str, - qdrant_location: str, -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - location=qdrant_location, - ) - output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3) - - assert all( - (1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output - ) diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/common.py b/libs/community/tests/integration_tests/vectorstores/qdrant/common.py deleted file mode 100644 index a86519715..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/common.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import List - -from langchain_core.documents import Document - - -def qdrant_is_not_running() -> bool: - """Check if Qdrant is not running.""" - import requests - - try: - response = requests.get("http://localhost:6333", timeout=10.0) - response_json = response.json() - return response_json.get("title") != "qdrant - vector search engine" - except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): - return True - - -def assert_documents_equals(actual: List[Document], expected: 
List[Document]) -> None: - assert len(actual) == len(expected) - - for actual_doc, expected_doc in zip(actual, expected): - assert actual_doc.page_content == expected_doc.page_content - - assert "_id" in actual_doc.metadata - assert "_collection_name" in actual_doc.metadata - - actual_doc.metadata.pop("_id") - actual_doc.metadata.pop("_collection_name") - - assert actual_doc.metadata == expected_doc.metadata diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/fixtures.py b/libs/community/tests/integration_tests/vectorstores/qdrant/fixtures.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/test_add_texts.py b/libs/community/tests/integration_tests/vectorstores/qdrant/test_add_texts.py deleted file mode 100644 index bf579feb7..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/test_add_texts.py +++ /dev/null @@ -1,135 +0,0 @@ -import uuid -from typing import Optional - -import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.common import assert_documents_equals - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_add_documents_extends_existing_collection( - batch_size: int, vector_name: Optional[str] -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch: Qdrant = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - location=":memory:", - batch_size=batch_size, - vector_name=vector_name, - ) - - new_texts = ["foobar", "foobaz"] - docsearch.add_documents( - [Document(page_content=content) for content in new_texts], batch_size=batch_size - ) - output = docsearch.similarity_search("foobar", k=1) - # ConsistentFakeEmbeddings return the same query embedding as the first document - # embedding computed in `embedding.embed_documents`. 
Thus, "foo" embedding is the - # same as "foobar" embedding - assert_documents_equals(output, [Document(page_content="foobar")]) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None: - """Test end to end Qdrant.add_texts returns unique ids.""" - docsearch: Qdrant = Qdrant.from_texts( - ["foobar"], - ConsistentFakeEmbeddings(), - location=":memory:", - batch_size=batch_size, - ) - - ids = docsearch.add_texts(["foo", "bar", "baz"]) - assert 3 == len(ids) - assert 3 == len(set(ids)) - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_add_texts_stores_duplicated_texts(vector_name: Optional[str]) -> None: - """Test end to end Qdrant.add_texts stores duplicated texts separately.""" - from qdrant_client import QdrantClient - from qdrant_client.http import models as rest - - client = QdrantClient(":memory:") - collection_name = uuid.uuid4().hex - vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE) - if vector_name is not None: - vectors_config = {vector_name: vectors_config} - client.recreate_collection(collection_name, vectors_config=vectors_config) - - vec_store = Qdrant( - client, - collection_name, - embeddings=ConsistentFakeEmbeddings(), - vector_name=vector_name, - ) - ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}]) - - assert 2 == len(set(ids)) - assert 2 == client.count(collection_name).count - - -@pytest.mark.parametrize("batch_size", [1, 64]) -def test_qdrant_add_texts_stores_ids(batch_size: int) -> None: - """Test end to end Qdrant.add_texts stores provided ids.""" - from qdrant_client import QdrantClient - from qdrant_client.http import models as rest - - ids = [ - "fa38d572-4c31-4579-aedc-1960d79df6df", - "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", - ] - - client = QdrantClient(":memory:") - collection_name = uuid.uuid4().hex - client.recreate_collection( - collection_name, - vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE), - ) - - vec_store = Qdrant(client, collection_name, ConsistentFakeEmbeddings()) - returned_ids = vec_store.add_texts(["abc", "def"], ids=ids, batch_size=batch_size) - - assert all(first == second for first, second in zip(ids, returned_ids)) - assert 2 == client.count(collection_name).count - stored_ids = [point.id for point in client.scroll(collection_name)[0]] - assert set(ids) == set(stored_ids) - - -@pytest.mark.parametrize("vector_name", ["custom-vector"]) -def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) -> None: - """Test end to end Qdrant.add_texts stores named vectors if name is provided.""" - from qdrant_client import QdrantClient - from qdrant_client.http import models as rest - - collection_name = uuid.uuid4().hex - - client = QdrantClient(":memory:") - client.recreate_collection( - collection_name, - vectors_config={ - vector_name: rest.VectorParams(size=10, distance=rest.Distance.COSINE) - }, - ) - - vec_store = Qdrant( - client, - collection_name, - ConsistentFakeEmbeddings(), - vector_name=vector_name, - ) - vec_store.add_texts(["lorem", "ipsum", "dolor", "sit", "amet"]) - - assert 5 == client.count(collection_name).count - assert all( - vector_name in point.vector - for point in client.scroll(collection_name, with_vectors=True)[0] - ) diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/test_delete.py b/libs/community/tests/integration_tests/vectorstores/qdrant/test_delete.py deleted file mode 100644 index 6804fcf88..000000000 --- 
a/libs/community/tests/integration_tests/vectorstores/qdrant/test_delete.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: implement tests for delete diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py b/libs/community/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py deleted file mode 100644 index fc3d2d0ec..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/test_embedding_interface.py +++ /dev/null @@ -1,60 +0,0 @@ -import uuid -from typing import Callable, Optional - -import pytest -from langchain_core.embeddings import Embeddings - -from langchain_community.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) - - -@pytest.mark.parametrize( - ["embeddings", "embedding_function"], - [ - (ConsistentFakeEmbeddings(), None), - (ConsistentFakeEmbeddings().embed_query, None), - (None, ConsistentFakeEmbeddings().embed_query), - ], -) -def test_qdrant_embedding_interface( - embeddings: Optional[Embeddings], embedding_function: Optional[Callable] -) -> None: - """Test Qdrant may accept different types for embeddings.""" - from qdrant_client import QdrantClient - - client = QdrantClient(":memory:") - collection_name = uuid.uuid4().hex - - Qdrant( - client, - collection_name, - embeddings=embeddings, - embedding_function=embedding_function, - ) - - -@pytest.mark.parametrize( - ["embeddings", "embedding_function"], - [ - (ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings().embed_query), - (None, None), - ], -) -def test_qdrant_embedding_interface_raises_value_error( - embeddings: Optional[Embeddings], embedding_function: Optional[Callable] -) -> None: - """Test Qdrant requires only one method for embeddings.""" - from qdrant_client import QdrantClient - - client = QdrantClient(":memory:") - collection_name = uuid.uuid4().hex - - with pytest.raises(ValueError): - Qdrant( - client, - collection_name, - embeddings=embeddings, - embedding_function=embedding_function, - ) diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/test_from_existing_collection.py b/libs/community/tests/integration_tests/vectorstores/qdrant/test_from_existing_collection.py deleted file mode 100644 index 04a09c69f..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/test_from_existing_collection.py +++ /dev/null @@ -1,39 +0,0 @@ -import tempfile -import uuid - -import pytest - -from langchain_community.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) - - -@pytest.mark.parametrize("vector_name", ["custom-vector"]) -def test_qdrant_from_existing_collection_uses_same_collection(vector_name: str) -> None: - """Test if the Qdrant.from_existing_collection reuses the same collection.""" - from qdrant_client import QdrantClient - - collection_name = uuid.uuid4().hex - with tempfile.TemporaryDirectory() as tmpdir: - docs = ["foo"] - qdrant = Qdrant.from_texts( - docs, - embedding=ConsistentFakeEmbeddings(), - path=str(tmpdir), - collection_name=collection_name, - vector_name=vector_name, - ) - del qdrant - - qdrant = Qdrant.from_existing_collection( - embedding=ConsistentFakeEmbeddings(), - path=str(tmpdir), - collection_name=collection_name, - vector_name=vector_name, - ) - qdrant.add_texts(["baz", "bar"]) - del qdrant - - client = QdrantClient(path=str(tmpdir)) - assert 3 == client.count(collection_name).count diff --git 
a/libs/community/tests/integration_tests/vectorstores/qdrant/test_from_texts.py b/libs/community/tests/integration_tests/vectorstores/qdrant/test_from_texts.py deleted file mode 100644 index dadc4ea00..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/test_from_texts.py +++ /dev/null @@ -1,289 +0,0 @@ -import tempfile -import uuid -from typing import Optional - -import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import Qdrant -from langchain_community.vectorstores.qdrant import QdrantException -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.common import ( - assert_documents_equals, - qdrant_is_not_running, -) - - -def test_qdrant_from_texts_stores_duplicated_texts() -> None: - """Test end to end Qdrant.from_texts stores duplicated texts separately.""" - from qdrant_client import QdrantClient - - collection_name = uuid.uuid4().hex - - with tempfile.TemporaryDirectory() as tmpdir: - vec_store = Qdrant.from_texts( - ["abc", "abc"], - ConsistentFakeEmbeddings(), - collection_name=collection_name, - path=str(tmpdir), - ) - del vec_store - - client = QdrantClient(path=str(tmpdir)) - assert 2 == client.count(collection_name).count - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_from_texts_stores_ids( - batch_size: int, vector_name: Optional[str] -) -> None: - """Test end to end Qdrant.from_texts stores provided ids.""" - from qdrant_client import QdrantClient - - collection_name = uuid.uuid4().hex - with tempfile.TemporaryDirectory() as tmpdir: - ids = [ - "fa38d572-4c31-4579-aedc-1960d79df6df", - "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", - ] - vec_store = Qdrant.from_texts( - ["abc", "def"], - ConsistentFakeEmbeddings(), - ids=ids, - collection_name=collection_name, - path=str(tmpdir), - batch_size=batch_size, - vector_name=vector_name, - ) - del vec_store - - client = QdrantClient(path=str(tmpdir)) - assert 2 == client.count(collection_name).count - stored_ids = [point.id for point in client.scroll(collection_name)[0]] - assert set(ids) == set(stored_ids) - - -@pytest.mark.parametrize("vector_name", ["custom-vector"]) -def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str) -> None: - """Test end to end Qdrant.from_texts stores named vectors if name is provided.""" - from qdrant_client import QdrantClient - - collection_name = uuid.uuid4().hex - with tempfile.TemporaryDirectory() as tmpdir: - vec_store = Qdrant.from_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(), - collection_name=collection_name, - path=str(tmpdir), - vector_name=vector_name, - ) - del vec_store - - client = QdrantClient(path=str(tmpdir)) - assert 5 == client.count(collection_name).count - assert all( - vector_name in point.vector - for point in client.scroll(collection_name, with_vectors=True)[0] - ) - - -@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) -def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> None: - """Test if Qdrant.from_texts reuses the same collection""" - from qdrant_client import QdrantClient - - collection_name = uuid.uuid4().hex - embeddings = ConsistentFakeEmbeddings() - with tempfile.TemporaryDirectory() as tmpdir: - vec_store = Qdrant.from_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - embeddings, - collection_name=collection_name, - 
path=str(tmpdir), - vector_name=vector_name, - ) - del vec_store - - vec_store = Qdrant.from_texts( - ["foo", "bar"], - embeddings, - collection_name=collection_name, - path=str(tmpdir), - vector_name=vector_name, - ) - del vec_store - - client = QdrantClient(path=str(tmpdir)) - assert 7 == client.count(collection_name).count - - -@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) -def test_qdrant_from_texts_raises_error_on_different_dimensionality( - vector_name: Optional[str], -) -> None: - """Test if Qdrant.from_texts raises an exception if dimensionality does not match""" - collection_name = uuid.uuid4().hex - with tempfile.TemporaryDirectory() as tmpdir: - vec_store = Qdrant.from_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), - collection_name=collection_name, - path=str(tmpdir), - vector_name=vector_name, - ) - del vec_store - - with pytest.raises(QdrantException): - Qdrant.from_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), - collection_name=collection_name, - path=str(tmpdir), - vector_name=vector_name, - ) - - -@pytest.mark.parametrize( - ["first_vector_name", "second_vector_name"], - [ - (None, "custom-vector"), - ("custom-vector", None), - ("my-first-vector", "my-second_vector"), - ], -) -def test_qdrant_from_texts_raises_error_on_different_vector_name( - first_vector_name: Optional[str], - second_vector_name: Optional[str], -) -> None: - """Test if Qdrant.from_texts raises an exception if vector name does not match""" - collection_name = uuid.uuid4().hex - with tempfile.TemporaryDirectory() as tmpdir: - vec_store = Qdrant.from_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), - collection_name=collection_name, - path=str(tmpdir), - vector_name=first_vector_name, - ) - del vec_store - - with pytest.raises(QdrantException): - Qdrant.from_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), - collection_name=collection_name, - path=str(tmpdir), - vector_name=second_vector_name, - ) - - -def test_qdrant_from_texts_raises_error_on_different_distance() -> None: - """Test if Qdrant.from_texts raises an exception if distance does not match""" - collection_name = uuid.uuid4().hex - with tempfile.TemporaryDirectory() as tmpdir: - vec_store = Qdrant.from_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(), - collection_name=collection_name, - path=str(tmpdir), - distance_func="Cosine", - ) - del vec_store - - with pytest.raises(QdrantException) as excinfo: - Qdrant.from_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(), - collection_name=collection_name, - path=str(tmpdir), - distance_func="Euclid", - ) - - expected_message = ( - "configured for COSINE similarity, but requested EUCLID. 
Please set " - "`distance_func` parameter to `COSINE`" - ) - assert expected_message in str(excinfo.value) - - -@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) -def test_qdrant_from_texts_recreates_collection_on_force_recreate( - vector_name: Optional[str], -) -> None: - """Test if Qdrant.from_texts recreates the collection even if config mismatches""" - from qdrant_client import QdrantClient - - collection_name = uuid.uuid4().hex - with tempfile.TemporaryDirectory() as tmpdir: - vec_store = Qdrant.from_texts( - ["lorem", "ipsum", "dolor", "sit", "amet"], - ConsistentFakeEmbeddings(dimensionality=10), - collection_name=collection_name, - path=str(tmpdir), - vector_name=vector_name, - ) - del vec_store - - vec_store = Qdrant.from_texts( - ["foo", "bar"], - ConsistentFakeEmbeddings(dimensionality=5), - collection_name=collection_name, - path=str(tmpdir), - vector_name=vector_name, - force_recreate=True, - ) - del vec_store - - client = QdrantClient(path=str(tmpdir)) - assert 2 == client.count(collection_name).count - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -def test_qdrant_from_texts_stores_metadatas( - batch_size: int, content_payload_key: str, metadata_payload_key: str -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - location=":memory:", - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - ) - output = docsearch.similarity_search("foo", k=1) - assert_documents_equals( - output, [Document(page_content="foo", metadata={"page": 0})] - ) - - -@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") -def test_from_texts_passed_optimizers_config_and_on_disk_payload() -> None: - from qdrant_client import models - - collection_name = uuid.uuid4().hex - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - optimizers_config = models.OptimizersConfigDiff(memmap_threshold=1000) - vec_store = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - optimizers_config=optimizers_config, - on_disk_payload=True, - on_disk=True, - collection_name=collection_name, - ) - - collection_info = vec_store.client.get_collection(collection_name) - assert collection_info.config.params.vectors.on_disk is True - assert collection_info.config.optimizer_config.memmap_threshold == 1000 - assert collection_info.config.params.on_disk_payload is True diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py b/libs/community/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py deleted file mode 100644 index 05c743118..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Optional - -import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.common import assert_documents_equals - - -@pytest.mark.parametrize("batch_size", [1, 64]) 
-@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_max_marginal_relevance_search( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], -) -> None: - """Test end to end construction and MRR search.""" - from qdrant_client import models - - filter = models.Filter( - must=[ - models.FieldCondition( - key=f"{metadata_payload_key}.page", - match=models.MatchValue( - value=2, - ), - ), - ], - ) - - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - location=":memory:", - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - distance_func="EUCLID", # Euclid distance used to avoid normalization - ) - output = docsearch.max_marginal_relevance_search( - "foo", k=2, fetch_k=3, lambda_mult=0.0 - ) - assert_documents_equals( - output, - [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="baz", metadata={"page": 2}), - ], - ) - - output = docsearch.max_marginal_relevance_search( - "foo", k=2, fetch_k=3, lambda_mult=0.0, filter=filter - ) - assert_documents_equals( - output, - [Document(page_content="baz", metadata={"page": 2})], - ) diff --git a/libs/community/tests/integration_tests/vectorstores/qdrant/test_similarity_search.py b/libs/community/tests/integration_tests/vectorstores/qdrant/test_similarity_search.py deleted file mode 100644 index 5b51be93d..000000000 --- a/libs/community/tests/integration_tests/vectorstores/qdrant/test_similarity_search.py +++ /dev/null @@ -1,284 +0,0 @@ -from typing import Optional - -import numpy as np -import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import Qdrant -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.qdrant.common import assert_documents_equals - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - location=":memory:", - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - ) - output = docsearch.similarity_search("foo", k=1) - assert_documents_equals(actual=output, expected=[Document(page_content="foo")]) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_by_vector( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: 
Optional[str], -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - location=":memory:", - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - ) - embeddings = ConsistentFakeEmbeddings().embed_query("foo") - output = docsearch.similarity_search_by_vector(embeddings, k=1) - assert_documents_equals(output, [Document(page_content="foo")]) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_with_score_by_vector( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - location=":memory:", - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - vector_name=vector_name, - ) - embeddings = ConsistentFakeEmbeddings().embed_query("foo") - output = docsearch.similarity_search_with_score_by_vector(embeddings, k=1) - assert len(output) == 1 - document, score = output[0] - assert_documents_equals(actual=[document], expected=[Document(page_content="foo")]) - assert score >= 0 - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_filters( - batch_size: int, vector_name: Optional[str] -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - location=":memory:", - batch_size=batch_size, - vector_name=vector_name, - ) - - output = docsearch.similarity_search( - "foo", k=1, filter={"page": 1, "metadata": {"page": 2, "pages": [3]}} - ) - - assert_documents_equals( - actual=output, - expected=[ - Document( - page_content="bar", - metadata={"page": 1, "metadata": {"page": 2, "pages": [3, -1]}}, - ) - ], - ) - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_with_relevance_score_no_threshold( - vector_name: Optional[str], -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - location=":memory:", - vector_name=vector_name, - ) - output = docsearch.similarity_search_with_relevance_scores( - "foo", k=3, score_threshold=None - ) - assert len(output) == 3 - for i in range(len(output)): - assert round(output[i][1], 2) >= 0 - assert round(output[i][1], 2) <= 1 - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_with_relevance_score_with_threshold( - vector_name: Optional[str], -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 
1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - location=":memory:", - vector_name=vector_name, - ) - - score_threshold = 0.98 - kwargs = {"score_threshold": score_threshold} - output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) - assert len(output) == 1 - assert all([score >= score_threshold for _, score in output]) - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_with_relevance_score_with_threshold_and_filter( - vector_name: Optional[str], -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - location=":memory:", - vector_name=vector_name, - ) - score_threshold = 0.99 # for almost exact match - # test negative filter condition - negative_filter = {"page": 1, "metadata": {"page": 2, "pages": [3]}} - kwargs = {"filter": negative_filter, "score_threshold": score_threshold} - output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) - assert len(output) == 0 - # test positive filter condition - positive_filter = {"page": 0, "metadata": {"page": 1, "pages": [2]}} - kwargs = {"filter": positive_filter, "score_threshold": score_threshold} - output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) - assert len(output) == 1 - assert all([score >= score_threshold for _, score in output]) - - -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_filters_with_qdrant_filters( - vector_name: Optional[str], -) -> None: - """Test end to end construction and search.""" - from qdrant_client.http import models as rest - - texts = ["foo", "bar", "baz"] - metadatas = [ - {"page": i, "details": {"page": i + 1, "pages": [i + 2, -1]}} - for i in range(len(texts)) - ] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - metadatas=metadatas, - location=":memory:", - vector_name=vector_name, - ) - - qdrant_filter = rest.Filter( - must=[ - rest.FieldCondition( - key="metadata.page", - match=rest.MatchValue(value=1), - ), - rest.FieldCondition( - key="metadata.details.page", - match=rest.MatchValue(value=2), - ), - rest.FieldCondition( - key="metadata.details.pages", - match=rest.MatchAny(any=[3]), - ), - ] - ) - output = docsearch.similarity_search("foo", k=1, filter=qdrant_filter) - assert_documents_equals( - actual=output, - expected=[ - Document( - page_content="bar", - metadata={"page": 1, "details": {"page": 2, "pages": [3, -1]}}, - ) - ], - ) - - -@pytest.mark.parametrize("batch_size", [1, 64]) -@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) -@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) -@pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_similarity_search_with_relevance_scores( - batch_size: int, - content_payload_key: str, - metadata_payload_key: str, - vector_name: Optional[str], -) -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = Qdrant.from_texts( - texts, - ConsistentFakeEmbeddings(), - location=":memory:", - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - batch_size=batch_size, - 
vector_name=vector_name, - ) - output = docsearch.similarity_search_with_relevance_scores("foo", k=3) - - assert all( - (1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output - ) diff --git a/libs/community/tests/integration_tests/vectorstores/test_deeplake.py b/libs/community/tests/integration_tests/vectorstores/test_deeplake.py deleted file mode 100644 index 54261c8da..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_deeplake.py +++ /dev/null @@ -1,274 +0,0 @@ -"""Test Deep Lake functionality.""" - -from collections.abc import Iterator - -import pytest -from langchain_core.documents import Document -from pytest import FixtureRequest - -from langchain_community.vectorstores import DeepLake -from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings - - -@pytest.fixture -def deeplake_datastore() -> Iterator[DeepLake]: - texts = ["foo", "bar", "baz"] - metadatas = [{"page": str(i)} for i in range(len(texts))] - docsearch = DeepLake.from_texts( - dataset_path="./test_path", - texts=texts, - metadatas=metadatas, - embedding_function=FakeEmbeddings(), - overwrite=True, - ) - yield docsearch - - docsearch.delete_dataset() - - -@pytest.fixture(params=["L1", "L2", "max", "cos"]) -def distance_metric(request: FixtureRequest) -> str: - return request.param - - -def test_deeplake() -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - docsearch = DeepLake.from_texts( - dataset_path="mem://test_path", texts=texts, embedding=FakeEmbeddings() - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - -def test_deeplake_with_metadatas() -> None: - """Test end to end construction and search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": str(i)} for i in range(len(texts))] - docsearch = DeepLake.from_texts( - dataset_path="mem://test_path", - texts=texts, - embedding=FakeEmbeddings(), - metadatas=metadatas, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - -def test_deeplake_with_persistence(deeplake_datastore: DeepLake) -> None: - """Test end to end construction and search, with persistence.""" - output = deeplake_datastore.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - # Get a new VectorStore from the persisted directory - docsearch = DeepLake( - dataset_path=deeplake_datastore.vectorstore.dataset_handler.path, - embedding_function=FakeEmbeddings(), - ) - output = docsearch.similarity_search("foo", k=1) - - # Clean up - docsearch.delete_dataset() - - # Persist doesn't need to be called again - # Data will be automatically persisted on object deletion - # Or on program exit - - -def test_deeplake_overwrite_flag(deeplake_datastore: DeepLake) -> None: - """Test overwrite behavior""" - dataset_path = deeplake_datastore.vectorstore.dataset_handler.path - - output = deeplake_datastore.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - # Get a new VectorStore from the persisted directory, with no overwrite (implicit) - docsearch = DeepLake( - dataset_path=dataset_path, - embedding_function=FakeEmbeddings(), - ) - output = docsearch.similarity_search("foo", k=1) - # assert page still present - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - # Get a new VectorStore from the persisted directory, with no overwrite (explicit) - 
docsearch = DeepLake( - dataset_path=dataset_path, - embedding_function=FakeEmbeddings(), - overwrite=False, - ) - output = docsearch.similarity_search("foo", k=1) - # assert page still present - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - # Get a new VectorStore from the persisted directory, with overwrite - docsearch = DeepLake( - dataset_path=dataset_path, - embedding_function=FakeEmbeddings(), - overwrite=True, - ) - with pytest.raises(ValueError): - output = docsearch.similarity_search("foo", k=1) - - -def test_similarity_search(deeplake_datastore: DeepLake) -> None: - """Test similarity search.""" - distance_metric = "cos" - output = deeplake_datastore.similarity_search( - "foo", k=1, distance_metric=distance_metric - ) - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - tql_query = ( - f"SELECT * WHERE " - f"id=='{deeplake_datastore.vectorstore.dataset.id[0].numpy()[0]}'" - ) - - output = deeplake_datastore.similarity_search( - query="foo", tql_query=tql_query, k=1, distance_metric=distance_metric - ) - assert len(output) == 1 - - -def test_similarity_search_by_vector( - deeplake_datastore: DeepLake, distance_metric: str -) -> None: - """Test similarity search by vector.""" - embeddings = FakeEmbeddings().embed_documents(["foo", "bar", "baz"]) - output = deeplake_datastore.similarity_search_by_vector( - embeddings[1], k=1, distance_metric=distance_metric - ) - assert output == [Document(page_content="bar", metadata={"page": "1"})] - deeplake_datastore.delete_dataset() - - -def test_similarity_search_with_score( - deeplake_datastore: DeepLake, distance_metric: str -) -> None: - """Test similarity search with score.""" - deeplake_datastore.vectorstore.summary() - output, score = deeplake_datastore.similarity_search_with_score( - "foo", k=1, distance_metric=distance_metric - )[0] - assert output == Document(page_content="foo", metadata={"page": "0"}) - if distance_metric == "cos": - assert score == 1.0 - else: - assert score == 0.0 - deeplake_datastore.delete_dataset() - - -def test_similarity_search_with_filter( - deeplake_datastore: DeepLake, distance_metric: str -) -> None: - """Test similarity search.""" - - output = deeplake_datastore.similarity_search( - "foo", - k=1, - distance_metric=distance_metric, - filter={"metadata": {"page": "1"}}, - ) - assert output == [Document(page_content="bar", metadata={"page": "1"})] - deeplake_datastore.delete_dataset() - - -def test_max_marginal_relevance_search(deeplake_datastore: DeepLake) -> None: - """Test max marginal relevance search by vector.""" - - output = deeplake_datastore.max_marginal_relevance_search("foo", k=1, fetch_k=2) - - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - embeddings = FakeEmbeddings().embed_documents(["foo", "bar", "baz"]) - output = deeplake_datastore.max_marginal_relevance_search_by_vector( - embeddings[0], k=1, fetch_k=2 - ) - - assert output == [Document(page_content="foo", metadata={"page": "0"})] - deeplake_datastore.delete_dataset() - - -def test_delete_dataset_by_ids(deeplake_datastore: DeepLake) -> None: - """Test delete dataset.""" - id = deeplake_datastore.vectorstore.dataset.id.data()["value"][0] - deeplake_datastore.delete(ids=[id]) - assert ( - deeplake_datastore.similarity_search( - "foo", k=1, filter={"metadata": {"page": "0"}} - ) - == [] - ) - assert len(deeplake_datastore.vectorstore) == 2 - - deeplake_datastore.delete_dataset() - - -def test_delete_dataset_by_filter(deeplake_datastore: DeepLake) -> None: - 
"""Test delete dataset.""" - deeplake_datastore.delete(filter={"metadata": {"page": "1"}}) - assert ( - deeplake_datastore.similarity_search( - "bar", k=1, filter={"metadata": {"page": "1"}} - ) - == [] - ) - assert len(deeplake_datastore.vectorstore.dataset) == 2 - - deeplake_datastore.delete_dataset() - - -def test_delete_by_path(deeplake_datastore: DeepLake) -> None: - """Test delete dataset.""" - import deeplake - - path = deeplake_datastore.dataset_path - DeepLake.force_delete_by_path(path) - assert not deeplake.exists(path) - - -def test_add_texts(deeplake_datastore: DeepLake) -> None: - """Test add_texts dataset.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": str(i)} for i in range(len(texts))] - - deeplake_datastore.add_texts( - texts=texts, - metadatas=metadatas, - ) - - with pytest.raises(TypeError): - deeplake_datastore.add_texts( - texts=texts, - metada=metadatas, - ) - - -def test_ids_backwards_compatibility() -> None: - """Test that ids are backwards compatible.""" - db = DeepLake( - dataset_path="mem://test_path", - embedding_function=FakeEmbeddings(), - tensor_params=[ - {"name": "ids", "htype": "text"}, - {"name": "text", "htype": "text"}, - {"name": "embedding", "htype": "embedding"}, - {"name": "metadata", "htype": "json"}, - ], - ) - db.vectorstore.add( - ids=["1", "2", "3"], - text=["foo", "bar", "baz"], - embedding=FakeEmbeddings().embed_documents(["foo", "bar", "baz"]), - metadata=[{"page": str(i)} for i in range(3)], - ) - output = db.similarity_search("foo", k=1) - assert len(output) == 1 - - -def test_similarity_search_should_error_out_when_not_supported_kwargs_are_provided( - deeplake_datastore: DeepLake, -) -> None: - """Test that ids are backwards compatible.""" - with pytest.raises(TypeError): - deeplake_datastore.similarity_search("foo", k=1, not_supported_kwarg=True) diff --git a/libs/community/tests/integration_tests/vectorstores/test_hanavector.py b/libs/community/tests/integration_tests/vectorstores/test_hanavector.py deleted file mode 100644 index 5802d4ace..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_hanavector.py +++ /dev/null @@ -1,1624 +0,0 @@ -"""Test HANA vectorstore functionality.""" - -import os -import random -from types import ModuleType -from typing import Any, Dict, List - -import numpy as np -import pytest - -from langchain_community.vectorstores import HanaDB -from langchain_community.vectorstores.utils import DistanceStrategy -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, -) -from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import ( - DOCUMENTS, - TYPE_1_FILTERING_TEST_CASES, - TYPE_2_FILTERING_TEST_CASES, - TYPE_3_FILTERING_TEST_CASES, - TYPE_4_FILTERING_TEST_CASES, - TYPE_5_FILTERING_TEST_CASES, -) - -TYPE_4B_FILTERING_TEST_CASES = [ - # Test $nin, which is missing in TYPE_4_FILTERING_TEST_CASES - ( - {"name": {"$nin": ["adam", "bob"]}}, - [3], - ), -] - -try: - from hdbcli import dbapi - - hanadb_installed = True -except ImportError: - hanadb_installed = False - - -class NormalizedFakeEmbeddings(ConsistentFakeEmbeddings): - """Fake embeddings with normalization. 
For testing purposes.""" - - def normalize(self, vector: List[float]) -> List[float]: - """Normalize vector.""" - return [float(v / np.linalg.norm(vector)) for v in vector] - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - return [self.normalize(v) for v in super().embed_documents(texts)] - - def embed_query(self, text: str) -> List[float]: - return self.normalize(super().embed_query(text)) - - -embedding = NormalizedFakeEmbeddings() - - -class ConfigData: - def __init__(self) -> None: - self.conn: dbapi.Connection = None - self.schema_name: str = "" - - -test_setup = ConfigData() - - -def generateSchemaName(cursor: "dbapi.Cursor") -> str: - # return "Langchain" - cursor.execute( - "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM " - "DUMMY;" - ) - if cursor.has_result_set(): - rows = cursor.fetchall() - uid = rows[0][0] - else: - uid = random.randint(1, 100000000) - return f"VEC_{uid}" - - -def setup_module(module: ModuleType) -> None: - test_setup.conn = dbapi.connect( - address=os.environ.get("HANA_DB_ADDRESS"), - port=os.environ.get("HANA_DB_PORT"), - user=os.environ.get("HANA_DB_USER"), - password=os.environ.get("HANA_DB_PASSWORD"), - autocommit=True, - sslValidateCertificate=False, - # encrypt=True - ) - try: - cur = test_setup.conn.cursor() - test_setup.schema_name = generateSchemaName(cur) - sql_str = f"CREATE SCHEMA {test_setup.schema_name}" - cur.execute(sql_str) - sql_str = f"SET SCHEMA {test_setup.schema_name}" - cur.execute(sql_str) - except dbapi.ProgrammingError: - pass - finally: - cur.close() - - -def teardown_module(module: ModuleType) -> None: - # return - try: - cur = test_setup.conn.cursor() - sql_str = f"DROP SCHEMA {test_setup.schema_name} CASCADE" - cur.execute(sql_str) - except dbapi.ProgrammingError: - pass - finally: - cur.close() - - -@pytest.fixture -def texts() -> List[str]: - return ["foo", "bar", "baz", "bak", "cat"] - - -@pytest.fixture -def metadatas() -> List[dict[str, Any]]: - return [ - {"start": 0, "end": 100, "quality": "good", "ready": True}, - {"start": 100, "end": 200, "quality": "bad", "ready": False}, - {"start": 200, "end": 300, "quality": "ugly", "ready": True}, - {"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"}, - {"start": 300, "quality": "ugly", "Owner": "Steve"}, - ] - - -def drop_table(connection: "dbapi.Connection", table_name: str) -> None: - try: - cur = connection.cursor() - sql_str = f"DROP TABLE {table_name}" - cur.execute(sql_str) - except dbapi.ProgrammingError: - pass - finally: - cur.close() - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_non_existing_table() -> None: - """Test end to end construction and search.""" - table_name = "NON_EXISTING" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectordb = HanaDB( - connection=test_setup.conn, - embedding=embedding, - distance_strategy=DistanceStrategy.COSINE, - table_name=table_name, - ) - - assert vectordb._table_exists(table_name) - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_table_with_missing_columns() -> None: - table_name = "EXISTING_MISSING_COLS" - try: - drop_table(test_setup.conn, table_name) - cur = test_setup.conn.cursor() - sql_str = f"CREATE TABLE {table_name}(WRONG_COL NVARCHAR(500));" - cur.execute(sql_str) - finally: - cur.close() - - # Check if table is created - exception_occurred = False - try: - HanaDB( - connection=test_setup.conn, - 
embedding=embedding, - distance_strategy=DistanceStrategy.COSINE, - table_name=table_name, - ) - exception_occurred = False - except AttributeError: - exception_occurred = True - assert exception_occurred - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_table_with_nvarchar_content(texts: List[str]) -> None: - table_name = "EXISTING_NVARCHAR" - content_column = "TEST_TEXT" - metadata_column = "TEST_META" - vector_column = "TEST_VECTOR" - try: - drop_table(test_setup.conn, table_name) - cur = test_setup.conn.cursor() - sql_str = ( - f"CREATE TABLE {table_name}({content_column} NVARCHAR(2048), " - f"{metadata_column} NVARCHAR(2048), {vector_column} REAL_VECTOR);" - ) - cur.execute(sql_str) - finally: - cur.close() - - vectordb = HanaDB( - connection=test_setup.conn, - embedding=embedding, - distance_strategy=DistanceStrategy.COSINE, - table_name=table_name, - content_column=content_column, - metadata_column=metadata_column, - vector_column=vector_column, - ) - - vectordb.add_texts(texts=texts) - - # check that embeddings have been created in the table - number_of_texts = len(texts) - number_of_rows = -1 - sql_str = f"SELECT COUNT(*) FROM {table_name}" - cur = test_setup.conn.cursor() - cur.execute(sql_str) - if cur.has_result_set(): - rows = cur.fetchall() - number_of_rows = rows[0][0] - assert number_of_rows == number_of_texts - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_table_with_wrong_typed_columns() -> None: - table_name = "EXISTING_WRONG_TYPES" - content_column = "DOC_TEXT" - metadata_column = "DOC_META" - vector_column = "DOC_VECTOR" - try: - drop_table(test_setup.conn, table_name) - cur = test_setup.conn.cursor() - sql_str = ( - f"CREATE TABLE {table_name}({content_column} INTEGER, " - f"{metadata_column} INTEGER, {vector_column} INTEGER);" - ) - cur.execute(sql_str) - finally: - cur.close() - - # Check if table is created - exception_occurred = False - try: - HanaDB( - connection=test_setup.conn, - embedding=embedding, - distance_strategy=DistanceStrategy.COSINE, - table_name=table_name, - ) - exception_occurred = False - except AttributeError as err: - print(err) # noqa: T201 - exception_occurred = True - assert exception_occurred - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_non_existing_table_fixed_vector_length() -> None: - """Test end to end construction and search.""" - table_name = "NON_EXISTING" - vector_column = "MY_VECTOR" - vector_column_length = 42 - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectordb = HanaDB( - connection=test_setup.conn, - embedding=embedding, - distance_strategy=DistanceStrategy.COSINE, - table_name=table_name, - vector_column=vector_column, - vector_column_length=vector_column_length, - ) - - assert vectordb._table_exists(table_name) - vectordb._check_column( - table_name, vector_column, ["REAL_VECTOR"], vector_column_length - ) - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_add_texts(texts: List[str]) -> None: - table_name = "TEST_TABLE_ADD_TEXTS" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectordb = HanaDB( - connection=test_setup.conn, embedding=embedding, table_name=table_name - ) - - vectordb.add_texts(texts=texts) - - # check that embeddings have been created in the table - number_of_texts = len(texts) - 
number_of_rows = -1 - sql_str = f"SELECT COUNT(*) FROM {table_name}" - cur = test_setup.conn.cursor() - cur.execute(sql_str) - if cur.has_result_set(): - rows = cur.fetchall() - number_of_rows = rows[0][0] - assert number_of_rows == number_of_texts - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_from_texts(texts: List[str]) -> None: - table_name = "TEST_TABLE_FROM_TEXTS" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - # test if vectorDB is instance of HanaDB - assert isinstance(vectorDB, HanaDB) - - # check that embeddings have been created in the table - number_of_texts = len(texts) - number_of_rows = -1 - sql_str = f"SELECT COUNT(*) FROM {table_name}" - cur = test_setup.conn.cursor() - cur.execute(sql_str) - if cur.has_result_set(): - rows = cur.fetchall() - number_of_rows = rows[0][0] - assert number_of_rows == number_of_texts - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_simple(texts: List[str]) -> None: - table_name = "TEST_TABLE_SEARCH_SIMPLE" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - assert texts[0] == vectorDB.similarity_search(texts[0], 1)[0].page_content - assert texts[1] != vectorDB.similarity_search(texts[0], 1)[0].page_content - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_by_vector_simple(texts: List[str]) -> None: - table_name = "TEST_TABLE_SEARCH_SIMPLE_VECTOR" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - vector = embedding.embed_query(texts[0]) - assert texts[0] == vectorDB.similarity_search_by_vector(vector, 1)[0].page_content - assert texts[1] != vectorDB.similarity_search_by_vector(vector, 1)[0].page_content - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_simple_euclidean_distance( - texts: List[str], -) -> None: - table_name = "TEST_TABLE_SEARCH_EUCLIDIAN" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, - ) - - assert texts[0] == vectorDB.similarity_search(texts[0], 1)[0].page_content - assert texts[1] != vectorDB.similarity_search(texts[0], 1)[0].page_content - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_metadata( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_METADATA" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search(texts[0], 3) - - assert texts[0] == search_result[0].page_content - 
assert metadatas[0]["start"] == search_result[0].metadata["start"] - assert metadatas[0]["end"] == search_result[0].metadata["end"] - assert texts[1] != search_result[0].page_content - assert metadatas[1]["start"] != search_result[0].metadata["start"] - assert metadatas[1]["end"] != search_result[0].metadata["end"] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_metadata_filter( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_FILTER" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search(texts[0], 3, filter={"start": 100}) - - assert len(search_result) == 1 - assert texts[1] == search_result[0].page_content - assert metadatas[1]["start"] == search_result[0].metadata["start"] - assert metadatas[1]["end"] == search_result[0].metadata["end"] - - search_result = vectorDB.similarity_search( - texts[0], 3, filter={"start": 100, "end": 150} - ) - assert len(search_result) == 0 - - search_result = vectorDB.similarity_search( - texts[0], 3, filter={"start": 100, "end": 200} - ) - assert len(search_result) == 1 - assert texts[1] == search_result[0].page_content - assert metadatas[1]["start"] == search_result[0].metadata["start"] - assert metadatas[1]["end"] == search_result[0].metadata["end"] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_metadata_filter_string( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_FILTER_STRING" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search(texts[0], 3, filter={"quality": "bad"}) - - assert len(search_result) == 1 - assert texts[1] == search_result[0].page_content - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_metadata_filter_bool( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_FILTER_BOOL" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search(texts[0], 3, filter={"ready": False}) - - assert len(search_result) == 1 - assert texts[1] == search_result[0].page_content - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_metadata_filter_invalid_type( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_FILTER_INVALID_TYPE" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - exception_occurred = False - try: - vectorDB.similarity_search(texts[0], 3, filter={"wrong_type": 0.1}) - except ValueError: - exception_occurred = 
True - assert exception_occurred - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_score( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_SCORE" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search_with_score(texts[0], 3) - - assert search_result[0][0].page_content == texts[0] - assert search_result[0][1] == 1.0 - assert search_result[1][1] <= search_result[0][1] - assert search_result[2][1] <= search_result[1][1] - assert search_result[2][1] >= 0.0 - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_relevance_score( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_REL_SCORE" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search_with_relevance_scores(texts[0], 3) - - assert search_result[0][0].page_content == texts[0] - assert search_result[0][1] == 1.0 - assert search_result[1][1] <= search_result[0][1] - assert search_result[2][1] <= search_result[1][1] - assert search_result[2][1] >= 0.0 - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_relevance_score_with_euclidian_distance( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_REL_SCORE_EUCLIDIAN" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, - ) - - search_result = vectorDB.similarity_search_with_relevance_scores(texts[0], 3) - - assert search_result[0][0].page_content == texts[0] - assert search_result[0][1] == 1.0 - assert search_result[1][1] <= search_result[0][1] - assert search_result[2][1] <= search_result[1][1] - assert search_result[2][1] >= 0.0 - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_similarity_search_with_score_with_euclidian_distance( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_SCORE_DISTANCE" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, - ) - - search_result = vectorDB.similarity_search_with_score(texts[0], 3) - - assert search_result[0][0].page_content == texts[0] - assert search_result[0][1] == 0.0 - assert search_result[1][1] >= search_result[0][1] - assert search_result[2][1] >= search_result[1][1] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_delete_with_filter(texts: List[str], metadatas: List[dict]) -> None: - table_name = "TEST_TABLE_DELETE_FILTER" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Fill table - vectorDB = 
HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search(texts[0], 10) - assert len(search_result) == 5 - - # Delete one of the three entries - assert vectorDB.delete(filter={"start": 100, "end": 200}) - - search_result = vectorDB.similarity_search(texts[0], 10) - assert len(search_result) == 4 - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -async def test_hanavector_delete_with_filter_async( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_DELETE_FILTER_ASYNC" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Fill table - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search(texts[0], 10) - assert len(search_result) == 5 - - # Delete one of the three entries - assert await vectorDB.adelete(filter={"start": 100, "end": 200}) - - search_result = vectorDB.similarity_search(texts[0], 10) - assert len(search_result) == 4 - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_delete_all_with_empty_filter( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_DELETE_ALL" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Fill table - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.similarity_search(texts[0], 3) - assert len(search_result) == 3 - - # Delete all entries - assert vectorDB.delete(filter={}) - - search_result = vectorDB.similarity_search(texts[0], 3) - assert len(search_result) == 0 - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_delete_called_wrong( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_DELETE_FILTER_WRONG" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Fill table - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - # Delete without filter parameter - exception_occurred = False - try: - vectorDB.delete() - except ValueError: - exception_occurred = True - assert exception_occurred - - # Delete with ids parameter - exception_occurred = False - try: - vectorDB.delete(ids=["id1", "id"], filter={"start": 100, "end": 200}) - except ValueError: - exception_occurred = True - assert exception_occurred - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_max_marginal_relevance_search(texts: List[str]) -> None: - table_name = "TEST_TABLE_MAX_RELEVANCE" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.max_marginal_relevance_search(texts[0], k=2, fetch_k=20) - - assert len(search_result) == 2 - assert search_result[0].page_content == texts[0] - assert search_result[1].page_content != texts[0] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_max_marginal_relevance_search_vector(texts: 
List[str]) -> None: - table_name = "TEST_TABLE_MAX_RELEVANCE_VECTOR" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - search_result = vectorDB.max_marginal_relevance_search_by_vector( - embedding.embed_query(texts[0]), k=2, fetch_k=20 - ) - - assert len(search_result) == 2 - assert search_result[0].page_content == texts[0] - assert search_result[1].page_content != texts[0] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -async def test_hanavector_max_marginal_relevance_search_async(texts: List[str]) -> None: - table_name = "TEST_TABLE_MAX_RELEVANCE_ASYNC" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - search_result = await vectorDB.amax_marginal_relevance_search( - texts[0], k=2, fetch_k=20 - ) - - assert len(search_result) == 2 - assert search_result[0].page_content == texts[0] - assert search_result[1].page_content != texts[0] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_filter_prepared_statement_params( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "TEST_TABLE_FILTER_PARAM" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - # Check if table is created - HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - ) - - cur = test_setup.conn.cursor() - sql_str = ( - f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.start') = '100'" - ) - cur.execute(sql_str) - rows = cur.fetchall() - assert len(rows) == 1 - - query_value = 100 - sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.start') = ?" - cur.execute(sql_str, (query_value)) - rows = cur.fetchall() - assert len(rows) == 1 - - sql_str = ( - f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.quality') = 'good'" - ) - cur.execute(sql_str) - rows = cur.fetchall() - assert len(rows) == 1 - - query_value = "good" # type: ignore[assignment] - sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.quality') = ?" - cur.execute(sql_str, (query_value)) - rows = cur.fetchall() - assert len(rows) == 1 - - sql_str = ( - f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = false" - ) - cur.execute(sql_str) - rows = cur.fetchall() - assert len(rows) == 1 - - # query_value = True - query_value = "true" # type: ignore[assignment] - sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?" - cur.execute(sql_str, (query_value)) - rows = cur.fetchall() - assert len(rows) == 3 - - # query_value = False - query_value = "false" # type: ignore[assignment] - sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?" 
- cur.execute(sql_str, (query_value)) - rows = cur.fetchall() - assert len(rows) == 1 - - -def test_invalid_metadata_keys(texts: List[str], metadatas: List[dict]) -> None: - table_name = "TEST_TABLE_INVALID_METADATA" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - invalid_metadatas = [ - {"sta rt": 0, "end": 100, "quality": "good", "ready": True}, - ] - exception_occurred = False - try: - HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=invalid_metadatas, - embedding=embedding, - table_name=table_name, - ) - except ValueError: - exception_occurred = True - assert exception_occurred - - invalid_metadatas = [ - {"sta/nrt": 0, "end": 100, "quality": "good", "ready": True}, - ] - exception_occurred = False - try: - HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=invalid_metadatas, - embedding=embedding, - table_name=table_name, - ) - except ValueError: - exception_occurred = True - assert exception_occurred - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_table_mixed_case_names(texts: List[str]) -> None: - table_name = "MyTableName" - content_column = "TextColumn" - metadata_column = "MetaColumn" - vector_column = "VectorColumn" - - vectordb = HanaDB( - connection=test_setup.conn, - embedding=embedding, - distance_strategy=DistanceStrategy.COSINE, - table_name=table_name, - content_column=content_column, - metadata_column=metadata_column, - vector_column=vector_column, - ) - - vectordb.add_texts(texts=texts) - - # check that embeddings have been created in the table - number_of_texts = len(texts) - number_of_rows = -1 - sql_str = f'SELECT COUNT(*) FROM "{table_name}"' - cur = test_setup.conn.cursor() - cur.execute(sql_str) - if cur.has_result_set(): - rows = cur.fetchall() - number_of_rows = rows[0][0] - assert number_of_rows == number_of_texts - - # check results of similarity search - assert texts[0] == vectordb.similarity_search(texts[0], 1)[0].page_content - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_hanavector_enhanced_filter_1() -> None: - table_name = "TEST_TABLE_ENHANCED_FILTER_1" - # Delete table if it exists - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB( - connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - vectorDB.add_documents(DOCUMENTS) - - -@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES) -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_pgvector_with_with_metadata_filters_1( - test_filter: Dict[str, Any], - expected_ids: List[int], -) -> None: - table_name = "TEST_TABLE_ENHANCED_FILTER_1" - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB( - connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - vectorDB.add_documents(DOCUMENTS) - - docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) - ids = [doc.metadata["id"] for doc in docs] - assert len(ids) == len(expected_ids), test_filter - assert set(ids).issubset(expected_ids), test_filter - - -@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES) -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_pgvector_with_with_metadata_filters_2( - test_filter: Dict[str, Any], - expected_ids: List[int], -) -> None: - table_name = "TEST_TABLE_ENHANCED_FILTER_2" - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB( - 
connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - vectorDB.add_documents(DOCUMENTS) - - docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) - ids = [doc.metadata["id"] for doc in docs] - assert len(ids) == len(expected_ids), test_filter - assert set(ids).issubset(expected_ids), test_filter - - -@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES) -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_pgvector_with_with_metadata_filters_3( - test_filter: Dict[str, Any], - expected_ids: List[int], -) -> None: - table_name = "TEST_TABLE_ENHANCED_FILTER_3" - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB( - connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - vectorDB.add_documents(DOCUMENTS) - - docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) - ids = [doc.metadata["id"] for doc in docs] - assert len(ids) == len(expected_ids), test_filter - assert set(ids).issubset(expected_ids), test_filter - - -@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES) -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_pgvector_with_with_metadata_filters_4( - test_filter: Dict[str, Any], - expected_ids: List[int], -) -> None: - table_name = "TEST_TABLE_ENHANCED_FILTER_4" - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB( - connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - vectorDB.add_documents(DOCUMENTS) - - docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) - ids = [doc.metadata["id"] for doc in docs] - assert len(ids) == len(expected_ids), test_filter - assert set(ids).issubset(expected_ids), test_filter - - -@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4B_FILTERING_TEST_CASES) -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_pgvector_with_with_metadata_filters_4b( - test_filter: Dict[str, Any], - expected_ids: List[int], -) -> None: - table_name = "TEST_TABLE_ENHANCED_FILTER_4B" - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB( - connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - vectorDB.add_documents(DOCUMENTS) - - docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) - ids = [doc.metadata["id"] for doc in docs] - assert len(ids) == len(expected_ids), test_filter - assert set(ids).issubset(expected_ids), test_filter - - -@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES) -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_pgvector_with_with_metadata_filters_5( - test_filter: Dict[str, Any], - expected_ids: List[int], -) -> None: - table_name = "TEST_TABLE_ENHANCED_FILTER_5" - drop_table(test_setup.conn, table_name) - - vectorDB = HanaDB( - connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - vectorDB.add_documents(DOCUMENTS) - - docs = vectorDB.similarity_search("meow", k=5, filter=test_filter) - ids = [doc.metadata["id"] for doc in docs] - assert len(ids) == len(expected_ids), test_filter - assert set(ids).issubset(expected_ids), test_filter - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_preexisting_specific_columns_for_metadata_fill( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "PREEXISTING_FILTER_COLUMNS" - # drop_table(test_setup.conn, 
table_name) - - sql_str = ( - f'CREATE TABLE "{table_name}" (' - f'"VEC_TEXT" NCLOB, ' - f'"VEC_META" NCLOB, ' - f'"VEC_VECTOR" REAL_VECTOR, ' - f'"Owner" NVARCHAR(100), ' - f'"quality" NVARCHAR(100));' - ) - try: - cur = test_setup.conn.cursor() - cur.execute(sql_str) - finally: - cur.close() - - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - specific_metadata_columns=["Owner", "quality"], - ) - - c = 0 - try: - sql_str = f"SELECT COUNT(*) FROM {table_name} WHERE \"quality\"='ugly'" - cur = test_setup.conn.cursor() - cur.execute(sql_str) - if cur.has_result_set(): - rows = cur.fetchall() - c = rows[0][0] - finally: - cur.close() - assert c == 3 - - docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"}) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100}) - assert len(docs) == 1 - assert docs[0].page_content == "bar" - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 100, "quality": "good"} - ) - assert len(docs) == 0 - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 0, "quality": "good"} - ) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_preexisting_specific_columns_for_metadata_via_array( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "PREEXISTING_FILTER_COLUMNS_VIA_ARRAY" - # drop_table(test_setup.conn, table_name) - - sql_str = ( - f'CREATE TABLE "{table_name}" (' - f'"VEC_TEXT" NCLOB, ' - f'"VEC_META" NCLOB, ' - f'"VEC_VECTOR" REAL_VECTOR, ' - f'"Owner" NVARCHAR(100), ' - f'"quality" NVARCHAR(100));' - ) - try: - cur = test_setup.conn.cursor() - cur.execute(sql_str) - finally: - cur.close() - - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - specific_metadata_columns=["quality"], - ) - - c = 0 - try: - sql_str = f"SELECT COUNT(*) FROM {table_name} WHERE \"quality\"='ugly'" - cur = test_setup.conn.cursor() - cur.execute(sql_str) - if cur.has_result_set(): - rows = cur.fetchall() - c = rows[0][0] - finally: - cur.close() - assert c == 3 - - try: - sql_str = f"SELECT COUNT(*) FROM {table_name} WHERE \"Owner\"='Steve'" - cur = test_setup.conn.cursor() - cur.execute(sql_str) - if cur.has_result_set(): - rows = cur.fetchall() - c = rows[0][0] - finally: - cur.close() - assert c == 0 - - docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"}) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100}) - assert len(docs) == 1 - assert docs[0].page_content == "bar" - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 100, "quality": "good"} - ) - assert len(docs) == 0 - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 0, "quality": "good"} - ) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_preexisting_specific_columns_for_metadata_multiple_columns( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "PREEXISTING_FILTER_MULTIPLE_COLUMNS" - # drop_table(test_setup.conn, table_name) - - sql_str = ( - f'CREATE TABLE "{table_name}" (' - f'"VEC_TEXT" NCLOB, ' - f'"VEC_META" 
NCLOB, ' - f'"VEC_VECTOR" REAL_VECTOR, ' - f'"quality" NVARCHAR(100), ' - f'"start" INTEGER);' - ) - try: - cur = test_setup.conn.cursor() - cur.execute(sql_str) - finally: - cur.close() - - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - specific_metadata_columns=["quality", "start"], - ) - - docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"}) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100}) - assert len(docs) == 1 - assert docs[0].page_content == "bar" - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 100, "quality": "good"} - ) - assert len(docs) == 0 - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 0, "quality": "good"} - ) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_preexisting_specific_columns_for_metadata_empty_columns( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "PREEXISTING_FILTER_MULTIPLE_COLUMNS_EMPTY" - # drop_table(test_setup.conn, table_name) - - sql_str = ( - f'CREATE TABLE "{table_name}" (' - f'"VEC_TEXT" NCLOB, ' - f'"VEC_META" NCLOB, ' - f'"VEC_VECTOR" REAL_VECTOR, ' - f'"quality" NVARCHAR(100), ' - f'"ready" BOOLEAN, ' - f'"start" INTEGER);' - ) - try: - cur = test_setup.conn.cursor() - cur.execute(sql_str) - finally: - cur.close() - - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - specific_metadata_columns=["quality", "ready", "start"], - ) - - docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"}) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - docs = vectorDB.similarity_search("hello", k=5, filter={"start": 100}) - assert len(docs) == 1 - assert docs[0].page_content == "bar" - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 100, "quality": "good"} - ) - assert len(docs) == 0 - - docs = vectorDB.similarity_search( - "hello", k=5, filter={"start": 0, "quality": "good"} - ) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - - docs = vectorDB.similarity_search("hello", k=5, filter={"ready": True}) - assert len(docs) == 3 - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_preexisting_specific_columns_for_metadata_wrong_type_or_non_existing( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "PREEXISTING_FILTER_COLUMNS_WRONG_TYPE" - # drop_table(test_setup.conn, table_name) - - sql_str = ( - f'CREATE TABLE "{table_name}" (' - f'"VEC_TEXT" NCLOB, ' - f'"VEC_META" NCLOB, ' - f'"VEC_VECTOR" REAL_VECTOR, ' - f'"quality" INTEGER); ' - ) - try: - cur = test_setup.conn.cursor() - cur.execute(sql_str) - finally: - cur.close() - - # Check if table is created - exception_occurred = False - try: - HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - specific_metadata_columns=["quality"], - ) - exception_occurred = False - except dbapi.Error: # Nothing we should do here, hdbcli will throw an error - exception_occurred = True - assert exception_occurred # Check if table is created - - exception_occurred = False - try: - HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - 
metadatas=metadatas, - embedding=embedding, - table_name=table_name, - specific_metadata_columns=["NonExistingColumn"], - ) - exception_occurred = False - except AttributeError: # Nothing we should do here, hdbcli will throw an error - exception_occurred = True - assert exception_occurred - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_preexisting_specific_columns_for_returned_metadata_completeness( - texts: List[str], metadatas: List[dict] -) -> None: - table_name = "PREEXISTING_FILTER_COLUMNS_METADATA_COMPLETENESS" - # drop_table(test_setup.conn, table_name) - - sql_str = ( - f'CREATE TABLE "{table_name}" (' - f'"VEC_TEXT" NCLOB, ' - f'"VEC_META" NCLOB, ' - f'"VEC_VECTOR" REAL_VECTOR, ' - f'"quality" NVARCHAR(100), ' - f'"NonExisting" NVARCHAR(100), ' - f'"ready" BOOLEAN, ' - f'"start" INTEGER);' - ) - try: - cur = test_setup.conn.cursor() - cur.execute(sql_str) - finally: - cur.close() - - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - metadatas=metadatas, - embedding=embedding, - table_name=table_name, - specific_metadata_columns=["quality", "ready", "start", "NonExisting"], - ) - - docs = vectorDB.similarity_search("hello", k=5, filter={"quality": "good"}) - assert len(docs) == 1 - assert docs[0].page_content == "foo" - assert docs[0].metadata["end"] == 100 - assert docs[0].metadata["start"] == 0 - assert docs[0].metadata["quality"] == "good" - assert docs[0].metadata["ready"] - assert "NonExisting" not in docs[0].metadata.keys() - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_create_hnsw_index_with_default_values(texts: List[str]) -> None: - table_name = "TEST_TABLE_HNSW_INDEX_DEFAULT" - - # Delete table if it exists (cleanup from previous tests) - drop_table(test_setup.conn, table_name) - - # Create table and insert data - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - # Test the creation of HNSW index - try: - vectorDB.create_hnsw_index() - except Exception as e: - pytest.fail(f"Failed to create HNSW index: {e}") - - # Perform a search using the index to confirm its correctness - search_result = vectorDB.max_marginal_relevance_search(texts[0], k=2, fetch_k=20) - - assert len(search_result) == 2 - assert search_result[0].page_content == texts[0] - assert search_result[1].page_content != texts[0] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_create_hnsw_index_with_defined_values(texts: List[str]) -> None: - table_name = "TEST_TABLE_HNSW_INDEX_DEFINED" - - # Delete table if it exists (cleanup from previous tests) - drop_table(test_setup.conn, table_name) - - # Create table and insert data - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, - ) - - # Test the creation of HNSW index with specific values - try: - vectorDB.create_hnsw_index( - index_name="my_L2_index", ef_search=500, m=100, ef_construction=200 - ) - except Exception as e: - pytest.fail(f"Failed to create HNSW index with defined values: {e}") - - # Perform a search using the index to confirm its correctness - search_result = vectorDB.max_marginal_relevance_search(texts[0], k=2, fetch_k=20) - - assert len(search_result) == 2 - assert search_result[0].page_content == texts[0] - assert search_result[1].page_content != texts[0] - - -@pytest.mark.skipif(not 
hanadb_installed, reason="hanadb not installed") -def test_create_hnsw_index_after_initialization(texts: List[str]) -> None: - table_name = "TEST_TABLE_HNSW_INDEX_AFTER_INIT" - - drop_table(test_setup.conn, table_name) - - # Initialize HanaDB without adding documents yet - vectorDB = HanaDB( - connection=test_setup.conn, - embedding=embedding, - table_name=table_name, - ) - - # Create HNSW index before adding documents - vectorDB.create_hnsw_index( - index_name="index_pre_add", ef_search=400, m=50, ef_construction=150 - ) - - # Add texts after index creation - vectorDB.add_texts(texts=texts) - - # Perform similarity search using the index - search_result = vectorDB.similarity_search(texts[0], k=3) - - # Assert that search result is valid and has expected length - assert len(search_result) == 3 - assert search_result[0].page_content == texts[0] - assert search_result[1].page_content != texts[0] - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_duplicate_hnsw_index_creation(texts: List[str]) -> None: - table_name = "TEST_TABLE_HNSW_DUPLICATE_INDEX" - - # Delete table if it exists (cleanup from previous tests) - drop_table(test_setup.conn, table_name) - - # Create table and insert data - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - # Create HNSW index for the first time - vectorDB.create_hnsw_index( - index_name="index_cosine", - ef_search=300, - m=80, - ef_construction=100, - ) - - with pytest.raises(Exception): - vectorDB.create_hnsw_index(ef_search=300, m=80, ef_construction=100) - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_create_hnsw_index_invalid_m_value(texts: List[str]) -> None: - table_name = "TEST_TABLE_HNSW_INVALID_M" - - # Cleanup: drop the table if it exists - drop_table(test_setup.conn, table_name) - - # Create table and insert data - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - # Test invalid `m` value (too low) - with pytest.raises(ValueError): - vectorDB.create_hnsw_index(m=3) - - # Test invalid `m` value (too high) - with pytest.raises(ValueError): - vectorDB.create_hnsw_index(m=1001) - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_create_hnsw_index_invalid_ef_construction(texts: List[str]) -> None: - table_name = "TEST_TABLE_HNSW_INVALID_EF_CONSTRUCTION" - - # Cleanup: drop the table if it exists - drop_table(test_setup.conn, table_name) - - # Create table and insert data - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - # Test invalid `ef_construction` value (too low) - with pytest.raises(ValueError): - vectorDB.create_hnsw_index(ef_construction=0) - - # Test invalid `ef_construction` value (too high) - with pytest.raises(ValueError): - vectorDB.create_hnsw_index(ef_construction=100001) - - -@pytest.mark.skipif(not hanadb_installed, reason="hanadb not installed") -def test_create_hnsw_index_invalid_ef_search(texts: List[str]) -> None: - table_name = "TEST_TABLE_HNSW_INVALID_EF_SEARCH" - - # Cleanup: drop the table if it exists - drop_table(test_setup.conn, table_name) - - # Create table and insert data - vectorDB = HanaDB.from_texts( - connection=test_setup.conn, - texts=texts, - embedding=embedding, - table_name=table_name, - ) - - # Test invalid `ef_search` value (too low) - with 
pytest.raises(ValueError): - vectorDB.create_hnsw_index(ef_search=0) - - # Test invalid `ef_search` value (too high) - with pytest.raises(ValueError): - vectorDB.create_hnsw_index(ef_search=100001) diff --git a/libs/community/tests/integration_tests/vectorstores/test_milvus.py b/libs/community/tests/integration_tests/vectorstores/test_milvus.py deleted file mode 100644 index c5802d1eb..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_milvus.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Test Milvus functionality.""" - -from typing import Any, List, Optional - -from langchain_core.documents import Document - -from langchain_community.vectorstores import Milvus -from tests.integration_tests.vectorstores.fake_embeddings import ( - FakeEmbeddings, - fake_texts, -) - - -def _milvus_from_texts( - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - drop: bool = True, -) -> Milvus: - return Milvus.from_texts( - fake_texts, - FakeEmbeddings(), - metadatas=metadatas, - ids=ids, - connection_args={"host": "127.0.0.1", "port": "19530"}, - drop_old=drop, - ) - - -def _get_pks(expr: str, docsearch: Milvus) -> List[Any]: - return docsearch.get_pks(expr) # type: ignore[return-value] - - -def test_milvus() -> None: - """Test end to end construction and search.""" - docsearch = _milvus_from_texts() - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - -def test_milvus_with_metadata() -> None: - """Test with metadata""" - docsearch = _milvus_from_texts(metadatas=[{"label": "test"}] * len(fake_texts)) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"label": "test"})] - - -def test_milvus_with_id() -> None: - """Test with ids""" - ids = ["id_" + str(i) for i in range(len(fake_texts))] - docsearch = _milvus_from_texts(ids=ids) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - output = docsearch.delete(ids=ids) - assert output.delete_count == len(fake_texts) - - try: - ids = ["dup_id" for _ in fake_texts] - _milvus_from_texts(ids=ids) - except Exception as e: - assert isinstance(e, AssertionError) - - -def test_milvus_with_score() -> None: - """Test end to end construction and search with scores and IDs.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas) - output = docsearch.similarity_search_with_score("foo", k=3) - docs = [o[0] for o in output] - scores = [o[1] for o in output] - assert docs == [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="bar", metadata={"page": 1}), - Document(page_content="baz", metadata={"page": 2}), - ] - assert scores[0] < scores[1] < scores[2] - - -def test_milvus_max_marginal_relevance_search() -> None: - """Test end to end construction and MRR search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas) - output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) - assert output == [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="baz", metadata={"page": 2}), - ] - - -def test_milvus_add_extra() -> None: - """Test end to end construction and MRR search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas) - - 
docsearch.add_texts(texts, metadatas) - - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 6 - - -def test_milvus_no_drop() -> None: - """Test end to end construction and MRR search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas) - del docsearch - - docsearch = _milvus_from_texts(metadatas=metadatas, drop=False) - - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 6 - - -def test_milvus_get_pks() -> None: - """Test end to end construction and get pks with expr""" - texts = ["foo", "bar", "baz"] - metadatas = [{"id": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas) - expr = "id in [1,2]" - output = _get_pks(expr, docsearch) - assert len(output) == 2 - - -def test_milvus_delete_entities() -> None: - """Test end to end construction and delete entities""" - texts = ["foo", "bar", "baz"] - metadatas = [{"id": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas) - expr = "id in [1,2]" - pks = _get_pks(expr, docsearch) - result = docsearch.delete(pks) - assert result is True - - -def test_milvus_upsert_entities() -> None: - """Test end to end construction and upsert entities""" - texts = ["foo", "bar", "baz"] - metadatas = [{"id": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas) - expr = "id in [1,2]" - pks = _get_pks(expr, docsearch) - documents = [ - Document(page_content="test_1", metadata={"id": 1}), - Document(page_content="test_2", metadata={"id": 3}), - ] - ids = docsearch.upsert(pks, documents) - assert len(ids) == 2 # type: ignore[arg-type] - - -# if __name__ == "__main__": -# test_milvus() -# test_milvus_with_metadata() -# test_milvus_with_score() -# test_milvus_max_marginal_relevance_search() -# test_milvus_add_extra() -# test_milvus_no_drop() diff --git a/libs/community/tests/integration_tests/vectorstores/test_mongodb_atlas.py b/libs/community/tests/integration_tests/vectorstores/test_mongodb_atlas.py deleted file mode 100644 index fdf7747e1..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_mongodb_atlas.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Test MongoDB Atlas Vector Search functionality.""" - -from __future__ import annotations - -import os -from time import sleep -from typing import Any - -import pytest -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings - -from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch - -INDEX_NAME = "langchain-test-index" -NAMESPACE = "langchain_test_db.langchain_test_collection" -CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") -DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") - - -def get_collection() -> Any: - from pymongo import MongoClient - - test_client: MongoClient = MongoClient(CONNECTION_STRING) - return test_client[DB_NAME][COLLECTION_NAME] - - -@pytest.fixture() -def collection() -> Any: - return get_collection() - - -class TestMongoDBAtlasVectorSearch: - @classmethod - def setup_class(cls) -> None: - # ensure the test collection is empty - collection = get_collection() - assert collection.count_documents({}) == 0 - - @classmethod - def teardown_class(cls) -> None: - collection = get_collection() - # delete all the documents in the collection - collection.delete_many({}) - - @pytest.fixture(autouse=True) - def setup(self) -> None: - collection = get_collection() - # delete all the documents in 
the collection - collection.delete_many({}) - - def test_from_documents( - self, embedding_openai: Embeddings, collection: Any - ) -> None: - """Test end to end construction and search.""" - documents = [ - Document(page_content="Dogs are tough.", metadata={"a": 1}), - Document(page_content="Cats have fluff.", metadata={"b": 1}), - Document(page_content="What is a sandwich?", metadata={"c": 1}), - Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), - ] - vectorstore = MongoDBAtlasVectorSearch.from_documents( - documents, - embedding_openai, - collection=collection, - index_name=INDEX_NAME, - ) - sleep(1) # waits for mongot to update Lucene's index - output = vectorstore.similarity_search("Sandwich", k=1) - assert output[0].page_content == "What is a sandwich?" - assert output[0].metadata["c"] == 1 - - def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None: - texts = [ - "Dogs are tough.", - "Cats have fluff.", - "What is a sandwich?", - "That fence is purple.", - ] - vectorstore = MongoDBAtlasVectorSearch.from_texts( - texts, - embedding_openai, - collection=collection, - index_name=INDEX_NAME, - ) - sleep(1) # waits for mongot to update Lucene's index - output = vectorstore.similarity_search("Sandwich", k=1) - assert output[0].page_content == "What is a sandwich?" - - def test_from_texts_with_metadatas( - self, embedding_openai: Embeddings, collection: Any - ) -> None: - texts = [ - "Dogs are tough.", - "Cats have fluff.", - "What is a sandwich?", - "The fence is purple.", - ] - metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] - vectorstore = MongoDBAtlasVectorSearch.from_texts( - texts, - embedding_openai, - metadatas=metadatas, - collection=collection, - index_name=INDEX_NAME, - ) - sleep(1) # waits for mongot to update Lucene's index - output = vectorstore.similarity_search("Sandwich", k=1) - assert output[0].page_content == "What is a sandwich?" 
- assert output[0].metadata["c"] == 1 - - def test_from_texts_with_metadatas_and_pre_filter( - self, embedding_openai: Embeddings, collection: Any - ) -> None: - texts = [ - "Dogs are tough.", - "Cats have fluff.", - "What is a sandwich?", - "The fence is purple.", - ] - metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] - vectorstore = MongoDBAtlasVectorSearch.from_texts( - texts, - embedding_openai, - metadatas=metadatas, - collection=collection, - index_name=INDEX_NAME, - ) - sleep(1) # waits for mongot to update Lucene's index - output = vectorstore.similarity_search( - "Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}} - ) - assert output == [] - - def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None: - texts = ["foo", "foo", "fou", "foy"] - vectorstore = MongoDBAtlasVectorSearch.from_texts( - texts, - embedding_openai, - collection=collection, - index_name=INDEX_NAME, - ) - sleep(1) # waits for mongot to update Lucene's index - query = "foo" - output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1) - assert len(output) == len(texts) - assert output[0].page_content == "foo" - assert output[1].page_content != "foo" diff --git a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py deleted file mode 100644 index ad0fcaacc..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py +++ /dev/null @@ -1,987 +0,0 @@ -"""Test Neo4jVector functionality.""" - -import os -from math import isclose -from typing import Any, Dict, List, cast - -from langchain_core.documents import Document -from yaml import safe_load - -from langchain_community.graphs import Neo4jGraph -from langchain_community.vectorstores.neo4j_vector import ( - Neo4jVector, - SearchType, - _get_search_index_query, -) -from langchain_community.vectorstores.utils import DistanceStrategy -from tests.integration_tests.vectorstores.fake_embeddings import ( - AngularTwoDimensionalEmbeddings, - FakeEmbeddings, -) -from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import ( - DOCUMENTS, - TYPE_1_FILTERING_TEST_CASES, - TYPE_2_FILTERING_TEST_CASES, - TYPE_3_FILTERING_TEST_CASES, - TYPE_4_FILTERING_TEST_CASES, -) - -url = os.environ.get("NEO4J_URL", "bolt://localhost:7687") -username = os.environ.get("NEO4J_USERNAME", "neo4j") -password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") - -OS_TOKEN_COUNT = 1536 - -texts = ["foo", "bar", "baz", "It is the end of the world. 
Take shelter!"] - -""" -cd tests/integration_tests/vectorstores/docker-compose -docker-compose -f neo4j.yml up -""" - - -def drop_vector_indexes(store: Neo4jVector) -> None: - """Cleanup all vector indexes""" - all_indexes = store.query( - """ - SHOW INDEXES YIELD name, type - WHERE type IN ["VECTOR", "FULLTEXT"] - RETURN name - """ - ) - for index in all_indexes: - store.query(f"DROP INDEX `{index['name']}`") - - store.query("MATCH (n) DETACH DELETE n;") - - -class FakeEmbeddingsWithOsDimension(FakeEmbeddings): - """Fake embeddings functionality for testing.""" - - def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]: - """Return simple embeddings.""" - return [ - [float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(i + 1)] - for i in range(len(embedding_texts)) - ] - - def embed_query(self, text: str) -> List[float]: - """Return simple embeddings.""" - return [float(1.0)] * (OS_TOKEN_COUNT - 1) + [float(texts.index(text) + 1)] - - -def test_neo4jvector() -> None: - """Test end to end construction and search.""" - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_euclidean() -> None: - """Test euclidean distance""" - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_embeddings() -> None: - """Test end to end construction with embeddings and search.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_catch_wrong_index_name() -> None: - """Test if index name is misspelled, but node label and property are correct.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - existing = Neo4jVector.from_existing_index( - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="test", - ) - output = existing.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(existing) - - -def test_neo4jvector_catch_wrong_node_label() -> None: - """Test if node label is misspelled, but index name is correct.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - 
embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - existing = Neo4jVector.from_existing_index( - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="vector", - node_label="test", - ) - output = existing.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(existing) - - -def test_neo4jvector_with_metadatas() -> None: - """Test end to end construction and search.""" - metadatas = [{"page": str(i)} for i in range(len(texts))] - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddingsWithOsDimension(), - metadatas=metadatas, - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"page": "0"})] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_with_metadatas_with_scores() -> None: - """Test end to end construction and search.""" - metadatas = [{"page": str(i)} for i in range(len(texts))] - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddingsWithOsDimension(), - metadatas=metadatas, - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - output = [ - (doc, round(score, 1)) - for doc, score in docsearch.similarity_search_with_score("foo", k=1) - ] - assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_relevance_score() -> None: - """Test to make sure the relevance score is scaled to 0-1.""" - metadatas = [{"page": str(i)} for i in range(len(texts))] - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddingsWithOsDimension(), - metadatas=metadatas, - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - - output = docsearch.similarity_search_with_relevance_scores("foo", k=3) - expected_output = [ - (Document(page_content="foo", metadata={"page": "0"}), 1.0), - (Document(page_content="bar", metadata={"page": "1"}), 0.9998376369476318), - (Document(page_content="baz", metadata={"page": "2"}), 0.9993523359298706), - ] - - # Check if the length of the outputs matches - assert len(output) == len(expected_output) - - # Check if each document and its relevance score is close to the expected value - for (doc, score), (expected_doc, expected_score) in zip(output, expected_output): - assert doc.page_content == expected_doc.page_content - assert doc.metadata == expected_doc.metadata - assert isclose(score, expected_score, rel_tol=1e-5) - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_retriever_search_threshold() -> None: - """Test using retriever for searching with threshold.""" - metadatas = [{"page": str(i)} for i in range(len(texts))] - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddingsWithOsDimension(), - metadatas=metadatas, - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - - retriever = docsearch.as_retriever( - search_type="similarity_score_threshold", - search_kwargs={"k": 3, "score_threshold": 0.9999}, - ) - output = retriever.invoke("foo") - assert output == [ - Document(page_content="foo", metadata={"page": "0"}), - ] - - drop_vector_indexes(docsearch) - - -def test_custom_return_neo4jvector() -> None: - """Test end to end construction and search.""" - docsearch = 
Neo4jVector.from_texts( - texts=["test"], - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - retrieval_query="RETURN 'foo' AS text, score, {test: 'test'} AS metadata", - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"test": "test"})] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_prefer_indexname() -> None: - """Test using when two indexes are found, prefer by index_name.""" - Neo4jVector.from_texts( - texts=["foo"], - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - - Neo4jVector.from_texts( - texts=["bar"], - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="foo", - node_label="Test", - embedding_node_property="vector", - text_node_property="info", - pre_delete_collection=True, - ) - - existing_index = Neo4jVector.from_existing_index( - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="foo", - text_node_property="info", - ) - - output = existing_index.similarity_search("bar", k=1) - assert output == [Document(page_content="bar", metadata={})] - drop_vector_indexes(existing_index) - - -def test_neo4jvector_prefer_indexname_insert() -> None: - """Test using when two indexes are found, prefer by index_name.""" - Neo4jVector.from_texts( - texts=["baz"], - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - - Neo4jVector.from_texts( - texts=["foo"], - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="foo", - node_label="Test", - embedding_node_property="vector", - text_node_property="info", - pre_delete_collection=True, - ) - - existing_index = Neo4jVector.from_existing_index( - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="foo", - text_node_property="info", - ) - - existing_index.add_documents([Document(page_content="bar", metadata={})]) - - output = existing_index.similarity_search("bar", k=2) - assert output == [ - Document(page_content="bar", metadata={}), - Document(page_content="foo", metadata={}), - ] - drop_vector_indexes(existing_index) - - -def test_neo4jvector_hybrid() -> None: - """Test end to end construction with hybrid search.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - search_type=SearchType.HYBRID, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_hybrid_deduplicate() -> None: - """Test result deduplication with hybrid search.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - search_type=SearchType.HYBRID, - ) - 
output = docsearch.similarity_search("foo", k=3) - assert output == [ - Document(page_content="foo"), - Document(page_content="bar"), - Document(page_content="baz"), - ] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_hybrid_retrieval_query() -> None: - """Test custom retrieval_query with hybrid search.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - search_type=SearchType.HYBRID, - retrieval_query="RETURN 'moo' AS text, score, {test: 'test'} AS metadata", - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="moo", metadata={"test": "test"})] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_hybrid_retrieval_query2() -> None: - """Test custom retrieval_query with hybrid search.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - search_type=SearchType.HYBRID, - retrieval_query="RETURN node.text AS text, score, {test: 'test'} AS metadata", - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"test": "test"})] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_missing_keyword() -> None: - """Test hybrid search with missing keyword_index_search.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - try: - Neo4jVector.from_existing_index( - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="vector", - search_type=SearchType.HYBRID, - ) - except ValueError as e: - assert str(e) == ( - "keyword_index name has to be specified when using hybrid search option" - ) - drop_vector_indexes(docsearch) - - -def test_neo4jvector_hybrid_from_existing() -> None: - """Test hybrid search with missing keyword_index_search.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - search_type=SearchType.HYBRID, - ) - existing = Neo4jVector.from_existing_index( - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - index_name="vector", - keyword_index_name="keyword", - search_type=SearchType.HYBRID, - ) - - output = existing.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(existing) - - -def test_neo4jvector_from_existing_graph() -> None: - """Test from_existing_graph with a single property.""" - graph = Neo4jVector.from_texts( - texts=["test"], - 
embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="foo",
-        node_label="Foo",
-        embedding_node_property="vector",
-        text_node_property="info",
-        pre_delete_collection=True,
-    )
-
-    graph.query("MATCH (n) DETACH DELETE n")
-
-    graph.query("CREATE (:Test {name:'Foo'}),(:Test {name:'Bar'})")
-
-    existing = Neo4jVector.from_existing_graph(
-        embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="vector",
-        node_label="Test",
-        text_node_properties=["name"],
-        embedding_node_property="embedding",
-    )
-
-    output = existing.similarity_search("foo", k=1)
-    assert output == [Document(page_content="\nname: Foo")]
-
-    drop_vector_indexes(existing)
-
-
-def test_neo4jvector_from_existing_graph_hybrid() -> None:
-    """Test from_existing_graph hybrid with a single property."""
-    graph = Neo4jVector.from_texts(
-        texts=["test"],
-        embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="foo",
-        node_label="Foo",
-        embedding_node_property="vector",
-        text_node_property="info",
-        pre_delete_collection=True,
-    )
-
-    graph.query("MATCH (n) DETACH DELETE n")
-
-    graph.query("CREATE (:Test {name:'foo'}),(:Test {name:'Bar'})")
-
-    existing = Neo4jVector.from_existing_graph(
-        embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="vector",
-        node_label="Test",
-        text_node_properties=["name"],
-        embedding_node_property="embedding",
-        search_type=SearchType.HYBRID,
-    )
-
-    output = existing.similarity_search("foo", k=1)
-    assert output == [Document(page_content="\nname: foo")]
-
-    drop_vector_indexes(existing)
-
-
-def test_neo4jvector_from_existing_graph_multiple_properties() -> None:
-    """Test from_existing_graph with two properties."""
-    graph = Neo4jVector.from_texts(
-        texts=["test"],
-        embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="foo",
-        node_label="Foo",
-        embedding_node_property="vector",
-        text_node_property="info",
-        pre_delete_collection=True,
-    )
-    graph.query("MATCH (n) DETACH DELETE n")
-
-    graph.query("CREATE (:Test {name:'Foo', name2: 'Fooz'}),(:Test {name:'Bar'})")
-
-    existing = Neo4jVector.from_existing_graph(
-        embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="vector",
-        node_label="Test",
-        text_node_properties=["name", "name2"],
-        embedding_node_property="embedding",
-    )
-
-    output = existing.similarity_search("foo", k=1)
-    assert output == [Document(page_content="\nname: Foo\nname2: Fooz")]
-
-    drop_vector_indexes(existing)
-
-
-def test_neo4jvector_from_existing_graph_multiple_properties_hybrid() -> None:
-    """Test from_existing_graph with two properties."""
-    graph = Neo4jVector.from_texts(
-        texts=["test"],
-        embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="foo",
-        node_label="Foo",
-        embedding_node_property="vector",
-        text_node_property="info",
-        pre_delete_collection=True,
-    )
-    graph.query("MATCH (n) DETACH DELETE n")
-
-    graph.query("CREATE (:Test {name:'Foo', name2: 'Fooz'}),(:Test {name:'Bar'})")
-
-    existing = Neo4jVector.from_existing_graph(
-        embedding=FakeEmbeddingsWithOsDimension(),
-        url=url,
-        username=username,
-        password=password,
-        index_name="vector",
-        node_label="Test",
-        text_node_properties=["name", "name2"],
-        embedding_node_property="embedding",
-        
search_type=SearchType.HYBRID, - ) - - output = existing.similarity_search("foo", k=1) - assert output == [Document(page_content="\nname: Foo\nname2: Fooz")] - - drop_vector_indexes(existing) - - -def test_neo4jvector_special_character() -> None: - """Test removing lucene.""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(texts, text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - search_type=SearchType.HYBRID, - ) - output = docsearch.similarity_search( - "It is the end of the world. Take shelter!", k=1 - ) - assert output == [ - Document(page_content="It is the end of the world. Take shelter!", metadata={}) - ] - - drop_vector_indexes(docsearch) - - -def test_hybrid_score_normalization() -> None: - """Test if we can get two 1.0 documents with RRF""" - text_embeddings = FakeEmbeddingsWithOsDimension().embed_documents(texts) - text_embedding_pairs = list(zip(["foo"], text_embeddings)) - docsearch = Neo4jVector.from_embeddings( - text_embeddings=text_embedding_pairs, - embedding=FakeEmbeddingsWithOsDimension(), - url=url, - username=username, - password=password, - pre_delete_collection=True, - search_type=SearchType.HYBRID, - ) - # Remove deduplication part of the query - rrf_query = ( - _get_search_index_query(SearchType.HYBRID) - .rstrip("WITH node, max(score) AS score ORDER BY score DESC LIMIT $k") - .replace("UNION", "UNION ALL") - + "RETURN node.text AS text, score LIMIT 2" - ) - - output = docsearch.query( - rrf_query, - params={ - "index": "vector", - "k": 1, - "embedding": FakeEmbeddingsWithOsDimension().embed_query("foo"), - "query": "foo", - "keyword_index": "keyword", - }, - ) - # Both FT and Vector must return 1.0 score - assert output == [{"text": "foo", "score": 1.0}, {"text": "foo", "score": 1.0}] - drop_vector_indexes(docsearch) - - -def test_index_fetching() -> None: - """testing correct index creation and fetching""" - embeddings = FakeEmbeddings() - - def create_store( - node_label: str, index: str, text_properties: List[str] - ) -> Neo4jVector: - return Neo4jVector.from_existing_graph( - embedding=embeddings, - url=url, - username=username, - password=password, - index_name=index, - node_label=node_label, - text_node_properties=text_properties, - embedding_node_property="embedding", - ) - - def fetch_store(index_name: str) -> Neo4jVector: - store = Neo4jVector.from_existing_index( - embedding=embeddings, - url=url, - username=username, - password=password, - index_name=index_name, - ) - return store - - # create index 0 - index_0_str = "index0" - create_store("label0", index_0_str, ["text"]) - - # create index 1 - index_1_str = "index1" - create_store("label1", index_1_str, ["text"]) - - index_1_store = fetch_store(index_1_str) - assert index_1_store.index_name == index_1_str - - index_0_store = fetch_store(index_0_str) - assert index_0_store.index_name == index_0_str - drop_vector_indexes(index_1_store) - drop_vector_indexes(index_0_store) - - -def test_retrieval_params() -> None: - """Test if we use parameters in retrieval query""" - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddings(), - pre_delete_collection=True, - retrieval_query=""" - RETURN $test as text, score, {test: $test1} AS metadata - """, - ) - - output = docsearch.similarity_search( - "Foo", k=2, params={"test": "test", "test1": "test1"} - ) 
- assert output == [ - Document(page_content="test", metadata={"test": "test1"}), - Document(page_content="test", metadata={"test": "test1"}), - ] - drop_vector_indexes(docsearch) - - -def test_retrieval_dictionary() -> None: - """Test if we use parameters in retrieval query""" - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddings(), - pre_delete_collection=True, - retrieval_query=""" - RETURN { - name:'John', - age: 30, - skills: ["Python", "Data Analysis", "Machine Learning"]} as text, - score, {} AS metadata - """, - ) - expected_output = [ - Document( - page_content=( - "skills:\n- Python\n- Data Analysis\n- " - "Machine Learning\nage: 30\nname: John\n" - ) - ) - ] - - output = docsearch.similarity_search("Foo", k=1) - - def parse_document(doc: Document) -> Any: - return safe_load(doc.page_content) - - parsed_expected = [parse_document(doc) for doc in expected_output] - parsed_output = [parse_document(doc) for doc in output] - - assert parsed_output == parsed_expected - drop_vector_indexes(docsearch) - - -def test_metadata_filters_type1() -> None: - """Test metadata filters""" - docsearch = Neo4jVector.from_documents( - DOCUMENTS, - embedding=FakeEmbeddings(), - pre_delete_collection=True, - ) - # We don't test type 5, because LIKE has very SQL specific examples - for example in ( - TYPE_1_FILTERING_TEST_CASES - + TYPE_2_FILTERING_TEST_CASES - + TYPE_3_FILTERING_TEST_CASES - + TYPE_4_FILTERING_TEST_CASES - ): - filter_dict = cast(Dict[str, Any], example[0]) - output = docsearch.similarity_search("Foo", filter=filter_dict) - indices = cast(List[int], example[1]) - adjusted_indices = [index - 1 for index in indices] - expected_output = [DOCUMENTS[index] for index in adjusted_indices] - # We don't return id properties from similarity search by default - # Also remove any key where the value is None - for doc in expected_output: - if "id" in doc.metadata: - del doc.metadata["id"] - keys_with_none = [ - key for key, value in doc.metadata.items() if value is None - ] - for key in keys_with_none: - del doc.metadata[key] - - assert output == expected_output - drop_vector_indexes(docsearch) - - -def test_neo4jvector_relationship_index() -> None: - """Test end to end construction and search.""" - embeddings = FakeEmbeddingsWithOsDimension() - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=embeddings, - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - # Ingest data - docsearch.query( - ( - "CREATE ()-[:REL {text: 'foo', embedding: $e1}]->()" - ", ()-[:REL {text: 'far', embedding: $e2}]->()" - ), - params={ - "e1": embeddings.embed_query("foo"), - "e2": embeddings.embed_query("bar"), - }, - ) - # Create relationship index - docsearch.query( - """CREATE VECTOR INDEX `relationship` -FOR ()-[r:REL]-() ON (r.embedding) -OPTIONS {indexConfig: { - `vector.dimensions`: 1536, - `vector.similarity_function`: 'cosine' -}} -""" - ) - relationship_index = Neo4jVector.from_existing_relationship_index( - embeddings, index_name="relationship" - ) - - output = relationship_index.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_relationship_index_retrieval() -> None: - """Test end to end construction and search.""" - embeddings = FakeEmbeddingsWithOsDimension() - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=embeddings, - url=url, - username=username, - password=password, - pre_delete_collection=True, - ) - # Ingest data - 
docsearch.query( - ( - "CREATE ({node:'text'})-[:REL {text: 'foo', embedding: $e1}]->()" - ", ({node:'text'})-[:REL {text: 'far', embedding: $e2}]->()" - ), - params={ - "e1": embeddings.embed_query("foo"), - "e2": embeddings.embed_query("bar"), - }, - ) - # Create relationship index - docsearch.query( - """CREATE VECTOR INDEX `relationship` -FOR ()-[r:REL]-() ON (r.embedding) -OPTIONS {indexConfig: { - `vector.dimensions`: 1536, - `vector.similarity_function`: 'cosine' -}} -""" - ) - retrieval_query = ( - "RETURN relationship.text + '-' + startNode(relationship).node " - "AS text, score, {foo:'bar'} AS metadata" - ) - relationship_index = Neo4jVector.from_existing_relationship_index( - embeddings, index_name="relationship", retrieval_query=retrieval_query - ) - - output = relationship_index.similarity_search("foo", k=1) - assert output == [Document(page_content="foo-text", metadata={"foo": "bar"})] - - drop_vector_indexes(docsearch) - - -def test_neo4j_max_marginal_relevance_search() -> None: - """ - Test end to end construction and MMR search. - The embedding function used here ensures `texts` become - the following vectors on a circle (numbered v0 through v3): - - ______ v2 - / \ - / | v1 - v3 | . | query - | / v0 - |______/ (N.B. very crude drawing) - - With fetch_k==3 and k==2, when query is at (1, ), - one expects that v2 and v0 are returned (in some order). - """ - texts = ["-0.124", "+0.127", "+0.25", "+1.0"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Neo4jVector.from_texts( - texts, - metadatas=metadatas, - embedding=AngularTwoDimensionalEmbeddings(), - pre_delete_collection=True, - ) - - expected_set = { - ("+0.25", 2), - ("-0.124", 0), - } - - output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3) - output_set = { - (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output - } - assert output_set == expected_set - - drop_vector_indexes(docsearch) - - -def test_neo4jvector_passing_graph_object() -> None: - """Test end to end construction and search with passing graph object.""" - graph = Neo4jGraph() - # Rewrite env vars to make sure it fails if env is used - os.environ["NEO4J_URI"] = "foo" - docsearch = Neo4jVector.from_texts( - texts=texts, - embedding=FakeEmbeddingsWithOsDimension(), - graph=graph, - pre_delete_collection=True, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - drop_vector_indexes(docsearch) diff --git a/libs/community/tests/integration_tests/vectorstores/test_oraclevs.py b/libs/community/tests/integration_tests/vectorstores/test_oraclevs.py deleted file mode 100644 index f0ea54fb5..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_oraclevs.py +++ /dev/null @@ -1,955 +0,0 @@ -"""Test Oracle AI Vector Search functionality.""" - -# import required modules -import sys -import threading - -from langchain_community.embeddings import HuggingFaceEmbeddings -from langchain_community.vectorstores.oraclevs import ( - OracleVS, - _create_table, - _index_exists, - _table_exists, - create_index, - drop_index_if_exists, - drop_table_purge, -) -from langchain_community.vectorstores.utils import DistanceStrategy - -username = "" -password = "" -dsn = "" - - -############################ -####### table_exists ####### -############################ -def test_table_exists_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - connection = oracledb.connect(user=username, password=password, dsn=dsn) - except Exception: 
- sys.exit(1) - # 1. Existing Table:(all capital letters) - # expectation:True - _table_exists(connection, "V$TRANSACTION") - - # 2. Existing Table:(all small letters) - # expectation:True - _table_exists(connection, "v$transaction") - - # 3. Non-Existing Table - # expectation:false - _table_exists(connection, "Hello") - - # 4. Invalid Table Name - # Expectation:ORA-00903: invalid table name - try: - _table_exists(connection, "123") - except Exception: - pass - - # 5. Empty String - # Expectation:ORA-00903: invalid table name - try: - _table_exists(connection, "") - except Exception: - pass - - # 6. Special Character - # Expectation:ORA-00911: #: invalid character after FROM - try: - _table_exists(connection, "##4") - except Exception: - pass - - # 7. Table name length > 128 - # Expectation:ORA-00972: The identifier XXXXXXXXXX...XXXXXXXXXX... - # exceeds the maximum length of 128 bytes. - try: - _table_exists(connection, "x" * 129) - except Exception: - pass - - # 8. - # Expectation:True - _create_table(connection, "TB1", 65535) - - # 9. Toggle Case (like TaBlE) - # Expectation:True - _table_exists(connection, "Tb1") - drop_table_purge(connection, "TB1") - - # 10. Table_Name→ "हिन्दी" - # Expectation:True - _create_table(connection, '"हिन्दी"', 545) - _table_exists(connection, '"हिन्दी"') - drop_table_purge(connection, '"हिन्दी"') - - -############################ -####### create_table ####### -############################ - - -def test_create_table_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - connection = oracledb.connect(user=username, password=password, dsn=dsn) - except Exception: - sys.exit(1) - - # 1. New table - HELLO - # Dimension - 100 - # Expectation:table is created - _create_table(connection, "HELLO", 100) - - # 2. Existing table name - # HELLO - # Dimension - 110 - # Expectation:Nothing happens - _create_table(connection, "HELLO", 110) - drop_table_purge(connection, "HELLO") - - # 3. New Table - 123 - # Dimension - 100 - # Expectation:ORA-00903: invalid table name - try: - _create_table(connection, "123", 100) - drop_table_purge(connection, "123") - except Exception: - pass - - # 4. New Table - Hello123 - # Dimension - 65535 - # Expectation:table is created - _create_table(connection, "Hello123", 65535) - drop_table_purge(connection, "Hello123") - - # 5. New Table - T1 - # Dimension - 65536 - # Expectation:ORA-51801: VECTOR column type specification - # has an unsupported dimension count ('65536'). - try: - _create_table(connection, "T1", 65536) - drop_table_purge(connection, "T1") - except Exception: - pass - - # 6. New Table - T1 - # Dimension - 0 - # Expectation:ORA-51801: VECTOR column type specification has - # an unsupported dimension count (0). - try: - _create_table(connection, "T1", 0) - drop_table_purge(connection, "T1") - except Exception: - pass - - # 7. New Table - T1 - # Dimension - -1 - # Expectation:ORA-51801: VECTOR column type specification has - # an unsupported dimension count ('-'). - try: - _create_table(connection, "T1", -1) - drop_table_purge(connection, "T1") - except Exception: - pass - - # 8. New Table - T2 - # Dimension - '1000' - # Expectation:table is created - _create_table(connection, "T2", int("1000")) - drop_table_purge(connection, "T2") - - # 9. New Table - T3 - # Dimension - 100 passed as a variable - # Expectation:table is created - val = 100 - _create_table(connection, "T3", val) - drop_table_purge(connection, "T3") - - # 10. 
- # Expectation:ORA-00922: missing or invalid option - val2 = """H - ello""" - try: - _create_table(connection, val2, 545) - drop_table_purge(connection, val2) - except Exception: - pass - - # 11. New Table - हिन्दी - # Dimension - 545 - # Expectation:table is created - _create_table(connection, '"हिन्दी"', 545) - drop_table_purge(connection, '"हिन्दी"') - - # 12. - # Expectation:failure - user does not exist - try: - _create_table(connection, "U1.TB4", 128) - drop_table_purge(connection, "U1.TB4") - except Exception: - pass - - # 13. - # Expectation:table is created - _create_table(connection, '"T5"', 128) - drop_table_purge(connection, '"T5"') - - # 14. Toggle Case - # Expectation:table creation fails - try: - _create_table(connection, "TaBlE", 128) - drop_table_purge(connection, "TaBlE") - except Exception: - pass - - # 15. table_name as empty_string - # Expectation: ORA-00903: invalid table name - try: - _create_table(connection, "", 128) - drop_table_purge(connection, "") - _create_table(connection, '""', 128) - drop_table_purge(connection, '""') - except Exception: - pass - - # 16. Arithmetic Operations in dimension parameter - # Expectation:table is created - n = 1 - _create_table(connection, "T10", n + 500) - drop_table_purge(connection, "T10") - - # 17. String Operations in table_name&dimension parameter - # Expectation:table is created - _create_table(connection, "YaSh".replace("aS", "ok"), 500) - drop_table_purge(connection, "YaSh".replace("aS", "ok")) - - -################################## -####### create_hnsw_index ####### -################################## - - -def test_create_hnsw_index_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - connection = oracledb.connect(user=username, password=password, dsn=dsn) - except Exception: - sys.exit(1) - # 1. Table_name - TB1 - # New Index - # distance_strategy - DistanceStrategy.Dot_product - # Expectation:Index created - model1 = HuggingFaceEmbeddings( - model_name="sentence-transformers/paraphrase-mpnet-base-v2" - ) - vs = OracleVS(connection, model1, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs) - - # 2. Creating same index again - # Table_name - TB1 - # Expectation:Nothing happens - try: - create_index(connection, vs) - drop_index_if_exists(connection, "HNSW") - except Exception: - pass - drop_table_purge(connection, "TB1") - - # 3. Create index with following parameters: - # idx_name - hnsw_idx2 - # idx_type - HNSW - # Expectation:Index created - vs = OracleVS(connection, model1, "TB2", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, params={"idx_name": "hnsw_idx2", "idx_type": "HNSW"}) - drop_index_if_exists(connection, "hnsw_idx2") - drop_table_purge(connection, "TB2") - - # 4. Table Name - TB1 - # idx_name - "हिन्दी" - # idx_type - HNSW - # Expectation:Index created - try: - vs = OracleVS(connection, model1, "TB3", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, params={"idx_name": '"हिन्दी"', "idx_type": "HNSW"}) - drop_index_if_exists(connection, '"हिन्दी"') - except Exception: - pass - drop_table_purge(connection, "TB3") - - # 5. idx_name passed empty - # Expectation:ORA-01741: illegal zero-length identifier - try: - vs = OracleVS(connection, model1, "TB4", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, params={"idx_name": '""', "idx_type": "HNSW"}) - drop_index_if_exists(connection, '""') - except Exception: - pass - drop_table_purge(connection, "TB4") - - # 6. 
idx_type left empty - # Expectation:Index created - try: - vs = OracleVS(connection, model1, "TB5", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, params={"idx_name": "Hello", "idx_type": ""}) - drop_index_if_exists(connection, "Hello") - except Exception: - pass - drop_table_purge(connection, "TB5") - - # 7. efconstruction passed as parameter but not neighbours - # Expectation:Index created - vs = OracleVS(connection, model1, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index( - connection, - vs, - params={"idx_name": "idx11", "efConstruction": 100, "idx_type": "HNSW"}, - ) - drop_index_if_exists(connection, "idx11") - drop_table_purge(connection, "TB7") - - # 8. efconstruction passed as parameter as well as neighbours - # (for this idx_type parameter is also necessary) - # Expectation:Index created - vs = OracleVS(connection, model1, "TB8", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index( - connection, - vs, - params={ - "idx_name": "idx11", - "efConstruction": 100, - "neighbors": 80, - "idx_type": "HNSW", - }, - ) - drop_index_if_exists(connection, "idx11") - drop_table_purge(connection, "TB8") - - # 9. Limit of Values for(integer values): - # parallel - # efConstruction - # Neighbors - # Accuracy - # 0 - # Expectation:Index created - vs = OracleVS(connection, model1, "TB15", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index( - connection, - vs, - params={ - "idx_name": "idx11", - "efConstruction": 200, - "neighbors": 100, - "idx_type": "HNSW", - "parallel": 8, - "accuracy": 10, - }, - ) - drop_index_if_exists(connection, "idx11") - drop_table_purge(connection, "TB15") - - # 11. index_name as - # Expectation:U1 not present - try: - vs = OracleVS( - connection, model1, "U1.TB16", DistanceStrategy.EUCLIDEAN_DISTANCE - ) - create_index( - connection, - vs, - params={ - "idx_name": "U1.idx11", - "efConstruction": 200, - "neighbors": 100, - "idx_type": "HNSW", - "parallel": 8, - "accuracy": 10, - }, - ) - drop_index_if_exists(connection, "U1.idx11") - drop_table_purge(connection, "TB16") - except Exception: - pass - - # 12. Index_name size >129 - # Expectation:Index not created - try: - vs = OracleVS(connection, model1, "TB17", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, params={"idx_name": "x" * 129, "idx_type": "HNSW"}) - drop_index_if_exists(connection, "x" * 129) - except Exception: - pass - drop_table_purge(connection, "TB17") - - # 13. Index_name size 128 - # Expectation:Index created - vs = OracleVS(connection, model1, "TB18", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, params={"idx_name": "x" * 128, "idx_type": "HNSW"}) - drop_index_if_exists(connection, "x" * 128) - drop_table_purge(connection, "TB18") - - -################################## -####### index_exists ############# -################################## - - -def test_index_exists_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - connection = oracledb.connect(user=username, password=password, dsn=dsn) - except Exception: - sys.exit(1) - model1 = HuggingFaceEmbeddings( - model_name="sentence-transformers/paraphrase-mpnet-base-v2" - ) - # 1. Existing Index:(all capital letters) - # Expectation:true - vs = OracleVS(connection, model1, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, params={"idx_name": "idx11", "idx_type": "HNSW"}) - _index_exists(connection, "IDX11") - - # 2. Existing Table:(all small letters) - # Expectation:true - _index_exists(connection, "idx11") - - # 3. 
Non-Existing Index - # Expectation:False - _index_exists(connection, "Hello") - - # 4. Invalid Index Name - # Expectation:Error - try: - _index_exists(connection, "123") - except Exception: - pass - - # 5. Empty String - # Expectation:Error - try: - _index_exists(connection, "") - except Exception: - pass - try: - _index_exists(connection, "") - except Exception: - pass - - # 6. Special Character - # Expectation:Error - try: - _index_exists(connection, "##4") - except Exception: - pass - - # 7. Index name length > 128 - # Expectation:Error - try: - _index_exists(connection, "x" * 129) - except Exception: - pass - - # 8. - # Expectation:true - _index_exists(connection, "U1.IDX11") - - # 9. Toggle Case (like iDx11) - # Expectation:true - _index_exists(connection, "IdX11") - - # 10. Index_Name→ "हिन्दी" - # Expectation:true - drop_index_if_exists(connection, "idx11") - try: - create_index(connection, vs, params={"idx_name": '"हिन्दी"', "idx_type": "HNSW"}) - _index_exists(connection, '"हिन्दी"') - except Exception: - pass - drop_table_purge(connection, "TB1") - - -################################## -####### add_texts ################ -################################## - - -def test_add_texts_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - connection = oracledb.connect(user=username, password=password, dsn=dsn) - except Exception: - sys.exit(1) - # 1. Add 2 records to table - # Expectation:Successful - texts = ["Rohan", "Shailendra"] - metadata = [ - {"id": "100", "link": "Document Example Test 1"}, - {"id": "101", "link": "Document Example Test 2"}, - ] - model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") - vs_obj = OracleVS(connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE) - vs_obj.add_texts(texts, metadata) - drop_table_purge(connection, "TB1") - - # 2. Add record but metadata is not there - # Expectation:An exception occurred :: Either specify an 'ids' list or - # 'metadatas' with an 'id' attribute for each element. - model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") - vs_obj = OracleVS(connection, model, "TB2", DistanceStrategy.EUCLIDEAN_DISTANCE) - texts2 = ["Sri Ram", "Krishna"] - vs_obj.add_texts(texts2) - drop_table_purge(connection, "TB2") - - # 3. Add record with ids option - # ids are passed as string - # ids are passed as empty string - # ids are passed as multi-line string - # ids are passed as "" - # Expectations: - # Successful - # Successful - # Successful - # Successful - - vs_obj = OracleVS(connection, model, "TB4", DistanceStrategy.EUCLIDEAN_DISTANCE) - ids3 = ["114", "124"] - vs_obj.add_texts(texts2, ids=ids3) - drop_table_purge(connection, "TB4") - - vs_obj = OracleVS(connection, model, "TB5", DistanceStrategy.EUCLIDEAN_DISTANCE) - ids4 = ["", "134"] - vs_obj.add_texts(texts2, ids=ids4) - drop_table_purge(connection, "TB5") - - vs_obj = OracleVS(connection, model, "TB6", DistanceStrategy.EUCLIDEAN_DISTANCE) - ids5 = [ - """Good afternoon - my friends""", - "India", - ] - vs_obj.add_texts(texts2, ids=ids5) - drop_table_purge(connection, "TB6") - - vs_obj = OracleVS(connection, model, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE) - ids6 = ['"Good afternoon"', '"India"'] - vs_obj.add_texts(texts2, ids=ids6) - drop_table_purge(connection, "TB7") - - # 4. 
Add records with ids and metadatas
-    # Expectation:Successful
-    vs_obj = OracleVS(connection, model, "TB8", DistanceStrategy.EUCLIDEAN_DISTANCE)
-    texts3 = ["Sri Ram 6", "Krishna 6"]
-    ids7 = ["1", "2"]
-    metadata = [
-        {"id": "102", "link": "Document Example", "stream": "Science"},
-        {"id": "104", "link": "Document Example 45"},
-    ]
-    vs_obj.add_texts(texts3, metadata, ids=ids7)
-    drop_table_purge(connection, "TB8")
-
-    # 5. Add 10000 records
-    # Expectation:Successful
-    vs_obj = OracleVS(connection, model, "TB9", DistanceStrategy.EUCLIDEAN_DISTANCE)
-    texts4 = ["Sri Ram{0}".format(i) for i in range(1, 10000)]
-    ids8 = ["Hello{0}".format(i) for i in range(1, 10000)]
-    vs_obj.add_texts(texts4, ids=ids8)
-    drop_table_purge(connection, "TB9")
-
-    # 6. Add 2 different records concurrently
-    # Expectation:Successful
-    def add(val: str) -> None:
-        model = HuggingFaceEmbeddings(
-            model_name="sentence-transformers/all-mpnet-base-v2"
-        )
-        vs_obj = OracleVS(
-            connection, model, "TB10", DistanceStrategy.EUCLIDEAN_DISTANCE
-        )
-        texts5 = [val]
-        ids9 = texts5
-        vs_obj.add_texts(texts5, ids=ids9)
-
-    thread_1 = threading.Thread(target=add, args=("Sri Ram",))
-    thread_2 = threading.Thread(target=add, args=("Sri Krishna",))
-    thread_1.start()
-    thread_2.start()
-    thread_1.join()
-    thread_2.join()
-    drop_table_purge(connection, "TB10")
-
-    # 7. Add the same record twice, concurrently
-    # Expectation:Successful; one of the inserts gets a primary key violation error
-    def add1(val: str) -> None:
-        model = HuggingFaceEmbeddings(
-            model_name="sentence-transformers/all-mpnet-base-v2"
-        )
-        vs_obj = OracleVS(
-            connection, model, "TB11", DistanceStrategy.EUCLIDEAN_DISTANCE
-        )
-        texts = [val]
-        ids10 = texts
-        vs_obj.add_texts(texts, ids=ids10)
-
-    try:
-        thread_1 = threading.Thread(target=add1, args=("Sri Ram",))
-        thread_2 = threading.Thread(target=add1, args=("Sri Ram",))
-        thread_1.start()
-        thread_2.start()
-        thread_1.join()
-        thread_2.join()
-    except Exception:
-        pass
-    drop_table_purge(connection, "TB11")
-
-    # 8. create object with table name of type 
-    # Expectation:U1 does not exist
-    try:
-        vs_obj = OracleVS(connection, model, "U1.TB14", DistanceStrategy.DOT_PRODUCT)
-        for i in range(1, 10):
-            texts7 = ["Yash{0}".format(i)]
-            ids13 = ["1234{0}".format(i)]
-            vs_obj.add_texts(texts7, ids=ids13)
-        drop_table_purge(connection, "TB14")
-    except Exception:
-        pass
-
-
-##################################
-####### embed_documents(text) ####
-##################################
-def test_embed_documents_test() -> None:
-    try:
-        import oracledb
-    except ImportError:
-        return
-
-    try:
-        connection = oracledb.connect(user=username, password=password, dsn=dsn)
-    except Exception:
-        sys.exit(1)
-    # 1. String Example-'Sri Ram'
-    # Expectation:Vector Printed
-    model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-    vs_obj = OracleVS(connection, model, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE)
-
-    # 4. List
-    # Expectation:Vector Printed
-    vs_obj._embed_documents(["hello", "yash"])
-    drop_table_purge(connection, "TB7")
-
-
-##################################
-####### embed_query(text) ########
-##################################
-def test_embed_query_test() -> None:
-    try:
-        import oracledb
-    except ImportError:
-        return
-
-    try:
-        connection = oracledb.connect(user=username, password=password, dsn=dsn)
-    except Exception:
-        sys.exit(1)
-    # 1. 
String - # Expectation:Vector printed - model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") - vs_obj = OracleVS(connection, model, "TB8", DistanceStrategy.EUCLIDEAN_DISTANCE) - vs_obj._embed_query("Sri Ram") - drop_table_purge(connection, "TB8") - - # 3. Empty string - # Expectation:[] - vs_obj._embed_query("") - - -################################## -####### create_index ############# -################################## -def test_create_index_test() -> None: - try: - import oracledb - except ImportError: - return - - try: - connection = oracledb.connect(user=username, password=password, dsn=dsn) - except Exception: - sys.exit(1) - # 1. No optional parameters passed - # Expectation:Successful - model1 = HuggingFaceEmbeddings( - model_name="sentence-transformers/paraphrase-mpnet-base-v2" - ) - vs = OracleVS(connection, model1, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs) - drop_index_if_exists(connection, "HNSW") - drop_table_purge(connection, "TB1") - - # 2. ivf index - # Expectation:Successful - vs = OracleVS(connection, model1, "TB2", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, {"idx_type": "IVF", "idx_name": "IVF"}) - drop_index_if_exists(connection, "IVF") - drop_table_purge(connection, "TB2") - - # 3. ivf index with neighbour_part passed as parameter - # Expectation:Successful - vs = OracleVS(connection, model1, "TB3", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, {"idx_type": "IVF", "neighbor_part": 10}) - drop_index_if_exists(connection, "IVF") - drop_table_purge(connection, "TB3") - - # 4. ivf index with neighbour_part and accuracy passed as parameter - # Expectation:Successful - vs = OracleVS(connection, model1, "TB4", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index( - connection, vs, {"idx_type": "IVF", "neighbor_part": 10, "accuracy": 90} - ) - drop_index_if_exists(connection, "IVF") - drop_table_purge(connection, "TB4") - - # 5. ivf index with neighbour_part and parallel passed as parameter - # Expectation:Successful - vs = OracleVS(connection, model1, "TB5", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index( - connection, vs, {"idx_type": "IVF", "neighbor_part": 10, "parallel": 90} - ) - drop_index_if_exists(connection, "IVF") - drop_table_purge(connection, "TB5") - - # 6. ivf index and then perform dml(insert) - # Expectation:Successful - vs = OracleVS(connection, model1, "TB6", DistanceStrategy.EUCLIDEAN_DISTANCE) - create_index(connection, vs, {"idx_type": "IVF", "idx_name": "IVF"}) - texts = ["Sri Ram", "Krishna"] - vs.add_texts(texts) - # perform delete - vs.delete(["hello"]) - drop_index_if_exists(connection, "IVF") - drop_table_purge(connection, "TB6") - - # 7. 
ivf index with neighbour_part, parallel and accuracy passed as parameter
-    # Expectation:Successful
-    vs = OracleVS(connection, model1, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE)
-    create_index(
-        connection,
-        vs,
-        {"idx_type": "IVF", "neighbor_part": 10, "parallel": 90, "accuracy": 99},
-    )
-    drop_index_if_exists(connection, "IVF")
-    drop_table_purge(connection, "TB7")
-
-
-##################################
-####### perform_search ###########
-##################################
-def test_perform_search_test() -> None:
-    try:
-        import oracledb
-    except ImportError:
-        return
-
-    try:
-        connection = oracledb.connect(user=username, password=password, dsn=dsn)
-    except Exception:
-        sys.exit(1)
-    model1 = HuggingFaceEmbeddings(
-        model_name="sentence-transformers/paraphrase-mpnet-base-v2"
-    )
-    vs_1 = OracleVS(connection, model1, "TB10", DistanceStrategy.EUCLIDEAN_DISTANCE)
-    vs_2 = OracleVS(connection, model1, "TB11", DistanceStrategy.DOT_PRODUCT)
-    vs_3 = OracleVS(connection, model1, "TB12", DistanceStrategy.COSINE)
-    vs_4 = OracleVS(connection, model1, "TB13", DistanceStrategy.EUCLIDEAN_DISTANCE)
-    vs_5 = OracleVS(connection, model1, "TB14", DistanceStrategy.DOT_PRODUCT)
-    vs_6 = OracleVS(connection, model1, "TB15", DistanceStrategy.COSINE)
-
-    # vector store lists:
-    vs_list = [vs_1, vs_2, vs_3, vs_4, vs_5, vs_6]
-
-    for i, vs in enumerate(vs_list, start=1):
-        # insert data
-        texts = ["Yash", "Varanasi", "Yashaswi", "Mumbai", "BengaluruYash"]
-        metadatas = [
-            {"id": "hello"},
-            {"id": "105"},
-            {"id": "106"},
-            {"id": "yash"},
-            {"id": "108"},
-        ]
-
-        vs.add_texts(texts, metadatas)
-
-        # create index
-        if i == 1 or i == 2 or i == 3:
-            create_index(connection, vs, {"idx_type": "HNSW", "idx_name": f"IDX1{i}"})
-        else:
-            create_index(connection, vs, {"idx_type": "IVF", "idx_name": f"IDX1{i}"})
-
-        # perform search
-        query = "YashB"
-
-        filter = {"id": ["106", "108", "yash"]}
-
-        # similarity_search without filter
-        vs.similarity_search(query, 2)
-
-        # similarity_search with filter
-        vs.similarity_search(query, 2, filter=filter)
-
-        # Similarity search with relevance score
-        vs.similarity_search_with_score(query, 2)
-
-        # Similarity search with relevance score with filter
-        vs.similarity_search_with_score(query, 2, filter=filter)
-
-        # Max marginal relevance search
-        vs.max_marginal_relevance_search(query, 2, fetch_k=20, lambda_mult=0.5)
-
-        # Max marginal relevance search with filter
-        vs.max_marginal_relevance_search(
-            query, 2, fetch_k=20, lambda_mult=0.5, filter=filter
-        )
-
-    drop_table_purge(connection, "TB10")
-    drop_table_purge(connection, "TB11")
-    drop_table_purge(connection, "TB12")
-    drop_table_purge(connection, "TB13")
-    drop_table_purge(connection, "TB14")
-    drop_table_purge(connection, "TB15")
diff --git a/libs/community/tests/integration_tests/vectorstores/test_pinecone.py b/libs/community/tests/integration_tests/vectorstores/test_pinecone.py
deleted file mode 100644
index dad43f607..000000000
--- a/libs/community/tests/integration_tests/vectorstores/test_pinecone.py
+++ /dev/null
@@ -1,287 +0,0 @@
-import importlib
-import os
-import time
-import uuid
-from typing import TYPE_CHECKING, List
-
-import numpy as np
-import pytest
-from langchain_core.documents import Document
-
-from langchain_community.embeddings import OpenAIEmbeddings
-from langchain_community.vectorstores.pinecone import Pinecone
-
-if TYPE_CHECKING:
-    import pinecone
-
-index_name = "langchain-test-index"  # name of the index
-namespace_name = "langchain-test-namespace"  # name of the namespace
-dimension = 
1536 # dimension of the embeddings - - -def reset_pinecone() -> None: - assert os.environ.get("PINECONE_API_KEY") is not None - assert os.environ.get("PINECONE_ENVIRONMENT") is not None - - import pinecone - - importlib.reload(pinecone) - - pinecone.init( - api_key=os.environ.get("PINECONE_API_KEY"), - environment=os.environ.get("PINECONE_ENVIRONMENT"), - ) - - -class TestPinecone: - index: "pinecone.Index" - - @classmethod - def setup_class(cls) -> None: - import pinecone - - reset_pinecone() - - cls.index = pinecone.Index(index_name) - - if index_name in pinecone.list_indexes(): - index_stats = cls.index.describe_index_stats() - if index_stats["dimension"] == dimension: - # delete all the vectors in the index if the dimension is the same - # from all namespaces - index_stats = cls.index.describe_index_stats() - for _namespace_name in index_stats["namespaces"].keys(): - cls.index.delete(delete_all=True, namespace=_namespace_name) - - else: - pinecone.delete_index(index_name) - pinecone.create_index(name=index_name, dimension=dimension) - else: - pinecone.create_index(name=index_name, dimension=dimension) - - # ensure the index is empty - index_stats = cls.index.describe_index_stats() - assert index_stats["dimension"] == dimension - if index_stats["namespaces"].get(namespace_name) is not None: - assert index_stats["namespaces"][namespace_name]["vector_count"] == 0 - - @classmethod - def teardown_class(cls) -> None: - index_stats = cls.index.describe_index_stats() - for _namespace_name in index_stats["namespaces"].keys(): - cls.index.delete(delete_all=True, namespace=_namespace_name) - - reset_pinecone() - - @pytest.fixture(autouse=True) - def setup(self) -> None: - # delete all the vectors in the index - index_stats = self.index.describe_index_stats() - for _namespace_name in index_stats["namespaces"].keys(): - self.index.delete(delete_all=True, namespace=_namespace_name) - - reset_pinecone() - - @pytest.mark.vcr() - def test_from_texts( - self, texts: List[str], embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and search.""" - unique_id = uuid.uuid4().hex - needs = f"foobuu {unique_id} booo" - texts.insert(0, needs) - - docsearch = Pinecone.from_texts( - texts=texts, - embedding=embedding_openai, - index_name=index_name, - namespace=namespace_name, - ) - output = docsearch.similarity_search(unique_id, k=1, namespace=namespace_name) - assert output == [Document(page_content=needs)] - - @pytest.mark.vcr() - def test_from_texts_with_metadatas( - self, texts: List[str], embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and search.""" - - unique_id = uuid.uuid4().hex - needs = f"foobuu {unique_id} booo" - texts.insert(0, needs) - - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Pinecone.from_texts( - texts, - embedding_openai, - index_name=index_name, - metadatas=metadatas, - namespace=namespace_name, - ) - output = docsearch.similarity_search(needs, k=1, namespace=namespace_name) - - # TODO: why metadata={"page": 0.0}) instead of {"page": 0}? 
- assert output == [Document(page_content=needs, metadata={"page": 0.0})] - - @pytest.mark.vcr() - def test_from_texts_with_scores(self, embedding_openai: OpenAIEmbeddings) -> None: - """Test end to end construction and search with scores and IDs.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Pinecone.from_texts( - texts, - embedding_openai, - index_name=index_name, - metadatas=metadatas, - namespace=namespace_name, - ) - output = docsearch.similarity_search_with_score( - "foo", k=3, namespace=namespace_name - ) - docs = [o[0] for o in output] - scores = [o[1] for o in output] - sorted_documents = sorted(docs, key=lambda x: x.metadata["page"]) - - # TODO: why metadata={"page": 0.0}) instead of {"page": 0}, etc??? - assert sorted_documents == [ - Document(page_content="foo", metadata={"page": 0.0}), - Document(page_content="bar", metadata={"page": 1.0}), - Document(page_content="baz", metadata={"page": 2.0}), - ] - assert scores[0] > scores[1] > scores[2] - - def test_from_existing_index_with_namespaces( - self, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test that namespaces are properly handled.""" - # Create two indexes with the same name but different namespaces - texts_1 = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts_1))] - Pinecone.from_texts( - texts_1, - embedding_openai, - index_name=index_name, - metadatas=metadatas, - namespace=f"{index_name}-1", - ) - - texts_2 = ["foo2", "bar2", "baz2"] - metadatas = [{"page": i} for i in range(len(texts_2))] - - Pinecone.from_texts( - texts_2, - embedding_openai, - index_name=index_name, - metadatas=metadatas, - namespace=f"{index_name}-2", - ) - - # Search with namespace - docsearch = Pinecone.from_existing_index( - index_name=index_name, - embedding=embedding_openai, - namespace=f"{index_name}-1", - ) - output = docsearch.similarity_search("foo", k=20, namespace=f"{index_name}-1") - # check that we don't get results from the other namespace - page_contents = sorted(set([o.page_content for o in output])) - assert all(content in ["foo", "bar", "baz"] for content in page_contents) - assert all(content not in ["foo2", "bar2", "baz2"] for content in page_contents) - - def test_add_documents_with_ids( - self, texts: List[str], embedding_openai: OpenAIEmbeddings - ) -> None: - ids = [uuid.uuid4().hex for _ in range(len(texts))] - Pinecone.from_texts( - texts=texts, - ids=ids, - embedding=embedding_openai, - index_name=index_name, - namespace=index_name, - ) - index_stats = self.index.describe_index_stats() - assert index_stats["namespaces"][index_name]["vector_count"] == len(texts) - - ids_1 = [uuid.uuid4().hex for _ in range(len(texts))] - Pinecone.from_texts( - texts=texts, - ids=ids_1, - embedding=embedding_openai, - index_name=index_name, - namespace=index_name, - ) - index_stats = self.index.describe_index_stats() - assert index_stats["namespaces"][index_name]["vector_count"] == len(texts) * 2 - assert index_stats["total_vector_count"] == len(texts) * 2 - - @pytest.mark.vcr() - def test_relevance_score_bound(self, embedding_openai: OpenAIEmbeddings) -> None: - """Ensures all relevance scores are between 0 and 1.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Pinecone.from_texts( - texts, - embedding_openai, - index_name=index_name, - metadatas=metadatas, - ) - # wait for the index to be ready - time.sleep(20) - output = docsearch.similarity_search_with_relevance_scores("foo", k=3) - assert 
all( - (1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output - ) - - @pytest.mark.skipif(reason="slow to run for benchmark") - @pytest.mark.parametrize( - "pool_threads,batch_size,embeddings_chunk_size,data_multiplier", - [ - ( - 1, - 32, - 32, - 1000, - ), # simulate single threaded with embeddings_chunk_size = batch_size = 32 - ( - 1, - 32, - 1000, - 1000, - ), # simulate single threaded with embeddings_chunk_size = 1000 - ( - 4, - 32, - 1000, - 1000, - ), # simulate 4 threaded with embeddings_chunk_size = 1000 - (20, 64, 5000, 1000), - ], # simulate 20 threaded with embeddings_chunk_size = 5000 - ) - def test_from_texts_with_metadatas_benchmark( - self, - pool_threads: int, - batch_size: int, - embeddings_chunk_size: int, - data_multiplier: int, - documents: List[Document], - embedding_openai: OpenAIEmbeddings, - ) -> None: - """Test end to end construction and search.""" - - texts = [document.page_content for document in documents] * data_multiplier - uuids = [uuid.uuid4().hex for _ in range(len(texts))] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Pinecone.from_texts( - texts, - embedding_openai, - ids=uuids, - metadatas=metadatas, - index_name=index_name, - namespace=namespace_name, - pool_threads=pool_threads, - batch_size=batch_size, - embeddings_chunk_size=embeddings_chunk_size, - ) - - query = "What did the president say about Ketanji Brown Jackson" - _ = docsearch.similarity_search(query, k=1, namespace=namespace_name) diff --git a/libs/community/tests/integration_tests/vectorstores/test_vdms.py b/libs/community/tests/integration_tests/vectorstores/test_vdms.py deleted file mode 100644 index c389422ff..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_vdms.py +++ /dev/null @@ -1,375 +0,0 @@ -"""Test VDMS functionality.""" - -from __future__ import annotations - -import logging -import os -from typing import TYPE_CHECKING - -import pytest -from langchain_core.documents import Document - -from langchain_community.vectorstores import VDMS -from langchain_community.vectorstores.vdms import VDMS_Client, embedding2bytes -from tests.integration_tests.vectorstores.fake_embeddings import ( - ConsistentFakeEmbeddings, - FakeEmbeddings, -) - -if TYPE_CHECKING: - import vdms - -logging.basicConfig(level=logging.DEBUG) -embedding_function = FakeEmbeddings() - - -# The connection string matches the default settings in the docker-compose file -# located in the root of the repository: [root]/docker/docker-compose.yml -# To spin up a detached VDMS server: -# cd [root]/docker -# docker compose up -d vdms -@pytest.fixture -@pytest.mark.enable_socket -def vdms_client() -> vdms.vdms: - return VDMS_Client( - host=os.getenv("VDMS_DBHOST", "localhost"), - port=int(os.getenv("VDMS_DBPORT", 6025)), - ) - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_init_from_client(vdms_client: vdms.vdms) -> None: - _ = VDMS( - embedding=embedding_function, - client=vdms_client, - ) - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_from_texts_with_metadatas(vdms_client: vdms.vdms) -> None: - """Test end to end construction and search.""" - collection_name = "test_from_texts_with_metadatas" - texts = ["foo", "bar", "baz"] - ids = [f"test_from_texts_with_metadatas_{i}" for i in range(len(texts))] - metadatas = [{"page": str(i)} for i in range(1, len(texts) + 1)] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - metadatas=metadatas, - collection_name=collection_name, - 
client=vdms_client, - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [ - Document(page_content="foo", metadata={"page": "1", "id": ids[0]}) - ] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_from_texts_with_metadatas_with_scores(vdms_client: vdms.vdms) -> None: - """Test end to end construction and scored search.""" - collection_name = "test_from_texts_with_metadatas_with_scores" - texts = ["foo", "bar", "baz"] - ids = [f"test_from_texts_with_metadatas_with_scores_{i}" for i in range(len(texts))] - metadatas = [{"page": str(i)} for i in range(1, len(texts) + 1)] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - metadatas=metadatas, - collection_name=collection_name, - client=vdms_client, - ) - output = docsearch.similarity_search_with_score("foo", k=1, fetch_k=1) - assert output == [ - (Document(page_content="foo", metadata={"page": "1", "id": ids[0]}), 0.0) - ] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_from_texts_with_metadatas_with_scores_using_vector( - vdms_client: vdms.vdms, -) -> None: - """Test end to end construction and scored search, using embedding vector.""" - collection_name = "test_from_texts_with_metadatas_with_scores_using_vector" - texts = ["foo", "bar", "baz"] - ids = [f"test_from_texts_with_metadatas_{i}" for i in range(len(texts))] - metadatas = [{"page": str(i)} for i in range(1, len(texts) + 1)] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - metadatas=metadatas, - collection_name=collection_name, - client=vdms_client, - ) - output = docsearch._similarity_search_with_relevance_scores("foo", k=1) - assert output == [ - (Document(page_content="foo", metadata={"page": "1", "id": ids[0]}), 0.0) - ] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_search_filter(vdms_client: vdms.vdms) -> None: - """Test end to end construction and search with metadata filtering.""" - collection_name = "test_search_filter" - texts = ["far", "bar", "baz"] - ids = [f"test_search_filter_{i}" for i in range(len(texts))] - metadatas = [{"first_letter": "{}".format(text[0])} for text in texts] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - metadatas=metadatas, - collection_name=collection_name, - client=vdms_client, - ) - output = docsearch.similarity_search( - "far", k=1, filter={"first_letter": ["==", "f"]} - ) - assert output == [ - Document(page_content="far", metadata={"first_letter": "f", "id": ids[0]}) - ] - output = docsearch.similarity_search( - "far", k=2, filter={"first_letter": ["==", "b"]} - ) - assert output == [ - Document(page_content="bar", metadata={"first_letter": "b", "id": ids[1]}), - Document(page_content="baz", metadata={"first_letter": "b", "id": ids[2]}), - ] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_search_filter_with_scores(vdms_client: vdms.vdms) -> None: - """Test end to end construction and scored search with metadata filtering.""" - collection_name = "test_search_filter_with_scores" - texts = ["far", "bar", "baz"] - ids = [f"test_search_filter_with_scores_{i}" for i in range(len(texts))] - metadatas = [{"first_letter": "{}".format(text[0])} for text in texts] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - metadatas=metadatas, - collection_name=collection_name, - client=vdms_client, - ) - output = docsearch.similarity_search_with_score( - "far", k=1, 
filter={"first_letter": ["==", "f"]} - ) - assert output == [ - ( - Document(page_content="far", metadata={"first_letter": "f", "id": ids[0]}), - 0.0, - ) - ] - - output = docsearch.similarity_search_with_score( - "far", k=2, filter={"first_letter": ["==", "b"]} - ) - assert output == [ - ( - Document(page_content="bar", metadata={"first_letter": "b", "id": ids[1]}), - 1.0, - ), - ( - Document(page_content="baz", metadata={"first_letter": "b", "id": ids[2]}), - 4.0, - ), - ] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_mmr(vdms_client: vdms.vdms) -> None: - """Test end to end construction and search.""" - collection_name = "test_mmr" - texts = ["foo", "bar", "baz"] - ids = [f"test_mmr_{i}" for i in range(len(texts))] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - collection_name=collection_name, - client=vdms_client, - ) - output = docsearch.max_marginal_relevance_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"id": ids[0]})] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_mmr_by_vector(vdms_client: vdms.vdms) -> None: - """Test end to end construction and search.""" - collection_name = "test_mmr_by_vector" - texts = ["foo", "bar", "baz"] - ids = [f"test_mmr_by_vector_{i}" for i in range(len(texts))] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - collection_name=collection_name, - client=vdms_client, - ) - embedded_query = embedding_function.embed_query("foo") - output = docsearch.max_marginal_relevance_search_by_vector(embedded_query, k=1) - assert output == [Document(page_content="foo", metadata={"id": ids[0]})] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_with_include_parameter(vdms_client: vdms.vdms) -> None: - """Test end to end construction and include parameter.""" - collection_name = "test_with_include_parameter" - texts = ["foo", "bar", "baz"] - docsearch = VDMS.from_texts( - texts=texts, - embedding=embedding_function, - collection_name=collection_name, - client=vdms_client, - ) - - response, response_array = docsearch.get(collection_name, include=["embeddings"]) - for emb in embedding_function.embed_documents(texts): - assert embedding2bytes(emb) in response_array - - response, response_array = docsearch.get(collection_name) - assert response_array == [] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_update_document(vdms_client: vdms.vdms) -> None: - """Test the update_document function in the VDMS class.""" - collection_name = "test_update_document" - - # Make a consistent embedding - const_embedding_function = ConsistentFakeEmbeddings() - - # Initial document content and id - initial_content = "foo" - document_id = "doc1" - - # Create an instance of Document with initial content and metadata - original_doc = Document(page_content=initial_content, metadata={"page": "1"}) - - # Initialize a VDMS instance with the original document - docsearch = VDMS.from_documents( - client=vdms_client, - collection_name=collection_name, - documents=[original_doc], - embedding=const_embedding_function, - ids=[document_id], - ) - old_response, old_embedding = docsearch.get( - collection_name, - constraints={"id": ["==", document_id]}, - include=["metadata", "embeddings"], - ) - # old_embedding = response_array[0] - - # Define updated content for the document - updated_content = "updated foo" - - # Create a new Document instance with the updated content and the same id - 
updated_doc = Document(page_content=updated_content, metadata={"page": "1"}) - - # Update the document in the VDMS instance - docsearch.update_document( - collection_name, document_id=document_id, document=updated_doc - ) - - # Perform a similarity search with the updated content - output = docsearch.similarity_search(updated_content, k=3)[0] - - # Assert that the updated document is returned by the search - assert output == Document( - page_content=updated_content, metadata={"page": "1", "id": document_id} - ) - - # Assert that the new embedding is correct - new_response, new_embedding = docsearch.get( - collection_name, - constraints={"id": ["==", document_id]}, - include=["metadata", "embeddings"], - ) - # new_embedding = response_array[0] - - assert new_embedding[0] == embedding2bytes( - const_embedding_function.embed_documents([updated_content])[0] - ) - assert new_embedding != old_embedding - - assert ( - new_response[0]["FindDescriptor"]["entities"][0]["content"] - != old_response[0]["FindDescriptor"]["entities"][0]["content"] - ) - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_with_relevance_score(vdms_client: vdms.vdms) -> None: - """Test to make sure the relevance score is scaled to 0-1.""" - collection_name = "test_with_relevance_score" - texts = ["foo", "bar", "baz"] - ids = [f"test_relevance_scores_{i}" for i in range(len(texts))] - metadatas = [{"page": str(i)} for i in range(1, len(texts) + 1)] - docsearch = VDMS.from_texts( - texts=texts, - ids=ids, - embedding=embedding_function, - metadatas=metadatas, - collection_name=collection_name, - client=vdms_client, - ) - output = docsearch._similarity_search_with_relevance_scores("foo", k=3) - assert output == [ - (Document(page_content="foo", metadata={"page": "1", "id": ids[0]}), 0.0), - (Document(page_content="bar", metadata={"page": "2", "id": ids[1]}), 0.25), - (Document(page_content="baz", metadata={"page": "3", "id": ids[2]}), 1.0), - ] - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_add_documents_no_metadata(vdms_client: vdms.vdms) -> None: - collection_name = "test_add_documents_no_metadata" - db = VDMS( - collection_name=collection_name, - embedding=embedding_function, - client=vdms_client, - ) - db.add_documents([Document(page_content="foo")]) - - -@pytest.mark.requires("vdms") -@pytest.mark.enable_socket -def test_add_documents_mixed_metadata(vdms_client: vdms.vdms) -> None: - collection_name = "test_add_documents_mixed_metadata" - db = VDMS( - collection_name=collection_name, - embedding=embedding_function, - client=vdms_client, - ) - - docs = [ - Document(page_content="foo"), - Document(page_content="bar", metadata={"baz": 1}), - ] - ids = ["10", "11"] - actual_ids = db.add_documents(docs, ids=ids) - assert actual_ids == ids - - search = db.similarity_search("foo bar", k=2) - docs[0].metadata = {"id": ids[0]} - docs[1].metadata["id"] = ids[1] - assert sorted(search, key=lambda d: d.page_content) == sorted( - docs, key=lambda d: d.page_content - ) diff --git a/libs/community/tests/integration_tests/vectorstores/test_weaviate.py b/libs/community/tests/integration_tests/vectorstores/test_weaviate.py deleted file mode 100644 index 906621850..000000000 --- a/libs/community/tests/integration_tests/vectorstores/test_weaviate.py +++ /dev/null @@ -1,248 +0,0 @@ -"""Test Weaviate functionality.""" - -import logging -import os -import uuid -from typing import Generator, Union - -import pytest -from langchain_core.documents import Document - -from 
langchain_community.embeddings.openai import OpenAIEmbeddings -from langchain_community.vectorstores.weaviate import Weaviate -from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings - -logging.basicConfig(level=logging.DEBUG) - -""" -cd tests/integration_tests/vectorstores/docker-compose -docker compose -f weaviate.yml up -""" - - -class TestWeaviate: - @classmethod - def setup_class(cls) -> None: - if not os.getenv("OPENAI_API_KEY"): - raise ValueError("OPENAI_API_KEY environment variable is not set") - - @pytest.fixture(scope="class", autouse=True) - def weaviate_url(self) -> Union[str, Generator[str, None, None]]: # type: ignore[return] - """Return the weaviate url.""" - from weaviate import Client - - url = "http://localhost:8080" - yield url - - # Clear the test index - client = Client(url) - client.schema.delete_all() - - @pytest.mark.vcr(ignore_localhost=True) - def test_similarity_search_without_metadata( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and search without metadata.""" - texts = ["foo", "bar", "baz"] - docsearch = Weaviate.from_texts( - texts, - embedding_openai, - weaviate_url=weaviate_url, - ) - - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - - @pytest.mark.vcr(ignore_localhost=True) - def test_similarity_search_with_metadata( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and search with metadata.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Weaviate.from_texts( - texts, embedding_openai, metadatas=metadatas, weaviate_url=weaviate_url - ) - output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata={"page": 0})] - - @pytest.mark.vcr(ignore_localhost=True) - def test_similarity_search_with_metadata_and_filter( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and search with metadata.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Weaviate.from_texts( - texts, embedding_openai, metadatas=metadatas, weaviate_url=weaviate_url - ) - output = docsearch.similarity_search( - "foo", - k=2, - where_filter={"path": ["page"], "operator": "Equal", "valueNumber": 0}, - ) - assert output == [Document(page_content="foo", metadata={"page": 0})] - - @pytest.mark.vcr(ignore_localhost=True) - def test_similarity_search_with_metadata_and_additional( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and search with metadata and additional.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Weaviate.from_texts( - texts, embedding_openai, metadatas=metadatas, weaviate_url=weaviate_url - ) - output = docsearch.similarity_search( - "foo", - k=1, - additional=["certainty"], - ) - assert output == [ - Document( - page_content="foo", - metadata={"page": 0, "_additional": {"certainty": 1}}, - ) - ] - - @pytest.mark.vcr(ignore_localhost=True) - def test_similarity_search_with_uuids( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and search with uuids.""" - texts = ["foo", "bar", "baz"] - # Weaviate replaces the object if the UUID already exists - uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, "same-name") for 
text in texts] - - metadatas = [{"page": i} for i in range(len(texts))] - docsearch = Weaviate.from_texts( - texts, - embedding_openai, - metadatas=metadatas, - weaviate_url=weaviate_url, - uuids=uuids, - ) - output = docsearch.similarity_search("foo", k=2) - assert len(output) == 1 - - @pytest.mark.vcr(ignore_localhost=True) - def test_max_marginal_relevance_search( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and MRR search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - - docsearch = Weaviate.from_texts( - texts, embedding_openai, metadatas=metadatas, weaviate_url=weaviate_url - ) - # if lambda=1 the algorithm should be equivalent to standard ranking - standard_ranking = docsearch.similarity_search("foo", k=2) - output = docsearch.max_marginal_relevance_search( - "foo", k=2, fetch_k=3, lambda_mult=1.0 - ) - assert output == standard_ranking - - # if lambda=0 the algorithm should favour maximal diversity - output = docsearch.max_marginal_relevance_search( - "foo", k=2, fetch_k=3, lambda_mult=0.0 - ) - assert output == [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="bar", metadata={"page": 1}), - ] - - @pytest.mark.vcr(ignore_localhost=True) - def test_max_marginal_relevance_search_by_vector( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and MRR search by vector.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - - docsearch = Weaviate.from_texts( - texts, embedding_openai, metadatas=metadatas, weaviate_url=weaviate_url - ) - foo_embedding = embedding_openai.embed_query("foo") - - # if lambda=1 the algorithm should be equivalent to standard ranking - standard_ranking = docsearch.similarity_search("foo", k=2) - output = docsearch.max_marginal_relevance_search_by_vector( - foo_embedding, k=2, fetch_k=3, lambda_mult=1.0 - ) - assert output == standard_ranking - - # if lambda=0 the algorithm should favour maximal diversity - output = docsearch.max_marginal_relevance_search_by_vector( - foo_embedding, k=2, fetch_k=3, lambda_mult=0.0 - ) - assert output == [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="bar", metadata={"page": 1}), - ] - - @pytest.mark.vcr(ignore_localhost=True) - def test_max_marginal_relevance_search_with_filter( - self, weaviate_url: str, embedding_openai: OpenAIEmbeddings - ) -> None: - """Test end to end construction and MRR search.""" - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - - docsearch = Weaviate.from_texts( - texts, embedding_openai, metadatas=metadatas, weaviate_url=weaviate_url - ) - where_filter = {"path": ["page"], "operator": "Equal", "valueNumber": 0} - # if lambda=1 the algorithm should be equivalent to standard ranking - standard_ranking = docsearch.similarity_search( - "foo", k=2, where_filter=where_filter - ) - output = docsearch.max_marginal_relevance_search( - "foo", k=2, fetch_k=3, lambda_mult=1.0, where_filter=where_filter - ) - assert output == standard_ranking - - # if lambda=0 the algorithm should favour maximal diversity - output = docsearch.max_marginal_relevance_search( - "foo", k=2, fetch_k=3, lambda_mult=0.0, where_filter=where_filter - ) - assert output == [ - Document(page_content="foo", metadata={"page": 0}), - ] - - def test_add_texts_with_given_embedding(self, weaviate_url: str) -> None: - texts = ["foo", "bar", "baz"] - 
embedding = FakeEmbeddings() - - docsearch = Weaviate.from_texts( - texts, embedding=embedding, weaviate_url=weaviate_url - ) - - docsearch.add_texts(["foo"]) - output = docsearch.similarity_search_by_vector( - embedding.embed_query("foo"), k=2 - ) - assert output == [ - Document(page_content="foo"), - Document(page_content="foo"), - ] - - def test_add_texts_with_given_uuids(self, weaviate_url: str) -> None: - texts = ["foo", "bar", "baz"] - embedding = FakeEmbeddings() - uuids = [uuid.uuid5(uuid.NAMESPACE_DNS, text) for text in texts] - - docsearch = Weaviate.from_texts( - texts, - embedding=embedding, - weaviate_url=weaviate_url, - uuids=uuids, - ) - - # Weaviate replaces the object if the UUID already exists - docsearch.add_texts(["foo"], uuids=[uuids[0]]) - output = docsearch.similarity_search_by_vector( - embedding.embed_query("foo"), k=2 - ) - assert output[0] == Document(page_content="foo") - assert output[1] != Document(page_content="foo") diff --git a/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py b/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py index 8928ab87b..b856bbe82 100644 --- a/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py +++ b/libs/community/tests/unit_tests/chains/test_pebblo_retrieval.py @@ -23,7 +23,7 @@ ChainInput, SemanticContext, ) -from langchain_community.vectorstores.pinecone import Pinecone +from langchain_community.vectorstores.pgvector import PGVector from tests.unit_tests.llms.fake_llm import FakeLLM @@ -56,8 +56,8 @@ def retriever() -> FakeRetriever: """ retriever = FakeRetriever() retriever.search_kwargs = {} - # Set the class of vectorstore to Pinecone - retriever.vectorstore.__class__ = Pinecone + # Set the class of vectorstore to PGVector + retriever.vectorstore.__class__ = PGVector return retriever @@ -107,7 +107,7 @@ def test_validate_vectorstore(retriever: FakeRetriever) -> None: Test vectorstore validation """ - # No exception should be raised for supported vectorstores (Pinecone) + # No exception should be raised for supported vectorstores (PGVector) _ = PebbloRetrievalQA.from_chain_type( llm=FakeLLM(), chain_type="stuff", diff --git a/libs/community/tests/unit_tests/chat_loaders/test_imports.py b/libs/community/tests/unit_tests/chat_loaders/test_imports.py index dbe1af291..2ed9f201a 100644 --- a/libs/community/tests/unit_tests/chat_loaders/test_imports.py +++ b/libs/community/tests/unit_tests/chat_loaders/test_imports.py @@ -3,7 +3,6 @@ EXPECTED_ALL = [ "BaseChatLoader", "FolderFacebookMessengerChatLoader", - "GMailLoader", "IMessageChatLoader", "LangSmithDatasetChatLoader", "LangSmithRunChatLoader", diff --git a/libs/community/tests/unit_tests/chat_message_histories/test_imports.py b/libs/community/tests/unit_tests/chat_message_histories/test_imports.py index 4c14a0efd..e0121205c 100644 --- a/libs/community/tests/unit_tests/chat_message_histories/test_imports.py +++ b/libs/community/tests/unit_tests/chat_message_histories/test_imports.py @@ -1,7 +1,6 @@ from langchain_community.chat_message_histories import __all__, _module_lookup EXPECTED_ALL = [ - "AstraDBChatMessageHistory", "CassandraChatMessageHistory", "ChatMessageHistory", "CosmosDBChatMessageHistory", @@ -10,8 +9,6 @@ "FileChatMessageHistory", "FirestoreChatMessageHistory", "MomentoChatMessageHistory", - "MongoDBChatMessageHistory", - "Neo4jChatMessageHistory", "PostgresChatMessageHistory", "RedisChatMessageHistory", "RocksetChatMessageHistory", diff --git a/libs/community/tests/unit_tests/chat_models/test_anthropic.py 
b/libs/community/tests/unit_tests/chat_models/test_anthropic.py deleted file mode 100644 index 8e6a65974..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_anthropic.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Test Anthropic Chat API wrapper.""" - -import os -from typing import List - -import pytest -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage - -from langchain_community.chat_models import ChatAnthropic -from langchain_community.chat_models.anthropic import ( - convert_messages_to_prompt_anthropic, -) - -os.environ["ANTHROPIC_API_KEY"] = "foo" - - -@pytest.mark.requires("anthropic") -def test_anthropic_model_name_param() -> None: - llm = ChatAnthropic(model_name="foo") - assert llm.model == "foo" - - -@pytest.mark.requires("anthropic") -def test_anthropic_model_param() -> None: - llm = ChatAnthropic(model="foo") # type: ignore[call-arg] - assert llm.model == "foo" - - -@pytest.mark.requires("anthropic") -def test_anthropic_model_kwargs() -> None: - llm = ChatAnthropic(model_kwargs={"foo": "bar"}) - assert llm.model_kwargs == {"foo": "bar"} - - -@pytest.mark.requires("anthropic") -def test_anthropic_fields_in_model_kwargs() -> None: - """Test that for backwards compatibility fields can be passed in as model_kwargs.""" - llm = ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5}) - assert llm.max_tokens_to_sample == 5 - llm = ChatAnthropic(model_kwargs={"max_tokens": 5}) - assert llm.max_tokens_to_sample == 5 - - -@pytest.mark.requires("anthropic") -def test_anthropic_incorrect_field() -> None: - with pytest.warns(match="not default parameter"): - llm = ChatAnthropic(foo="bar") # type: ignore[call-arg] - assert llm.model_kwargs == {"foo": "bar"} - - -@pytest.mark.requires("anthropic") -def test_anthropic_initialization() -> None: - """Test anthropic initialization.""" - # Verify that chat anthropic can be initialized using a secret key provided - # as a parameter rather than an environment variable. 
- ChatAnthropic(model="test", anthropic_api_key="test") # type: ignore[arg-type, call-arg] - - -@pytest.mark.parametrize( - ("messages", "expected"), - [ - ([HumanMessage(content="Hello")], "\n\nHuman: Hello\n\nAssistant:"), - ( - [HumanMessage(content="Hello"), AIMessage(content="Answer:")], - "\n\nHuman: Hello\n\nAssistant: Answer:", - ), - ( - [ - SystemMessage(content="You're an assistant"), - HumanMessage(content="Hello"), - AIMessage(content="Answer:"), - ], - "You're an assistant\n\nHuman: Hello\n\nAssistant: Answer:", - ), - ], -) -def test_formatting(messages: List[BaseMessage], expected: str) -> None: - result = convert_messages_to_prompt_anthropic(messages) - assert result == expected diff --git a/libs/community/tests/unit_tests/chat_models/test_azureopenai.py b/libs/community/tests/unit_tests/chat_models/test_azureopenai.py deleted file mode 100644 index 49419ed7d..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_azureopenai.py +++ /dev/null @@ -1,55 +0,0 @@ -import json -import os -from unittest import mock - -import pytest - -from langchain_community.chat_models.azure_openai import AzureChatOpenAI - - -@mock.patch.dict( - os.environ, - { - "OPENAI_API_KEY": "test", - "OPENAI_API_BASE": "https://oai.azure.com/", - "OPENAI_API_VERSION": "2023-05-01", - }, -) -@pytest.mark.requires("openai") -@pytest.mark.parametrize( - "model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"] -) -def test_model_name_set_on_chat_result_when_present_in_response( - model_name: str, -) -> None: - sample_response_text = f""" - {{ - "id": "chatcmpl-7ryweq7yc8463fas879t9hdkkdf", - "object": "chat.completion", - "created": 1690381189, - "model": "{model_name}", - "choices": [ - {{ - "index": 0, - "finish_reason": "stop", - "message": {{ - "role": "assistant", - "content": "I'm an AI assistant that can help you." 
-                }}
-            }}
-        ],
-        "usage": {{
-            "completion_tokens": 28,
-            "prompt_tokens": 15,
-            "total_tokens": 43
-        }}
-    }}
-    """
-    # convert sample_response_text to instance of Mapping[str, Any]
-    sample_response = json.loads(sample_response_text)
-    mock_chat = AzureChatOpenAI()
-    chat_result = mock_chat._create_chat_result(sample_response)
-    assert (
-        chat_result.llm_output is not None
-        and chat_result.llm_output["model_name"] == model_name
-    )
diff --git a/libs/community/tests/unit_tests/chat_models/test_bedrock.py b/libs/community/tests/unit_tests/chat_models/test_bedrock.py
deleted file mode 100644
index dc9de9537..000000000
--- a/libs/community/tests/unit_tests/chat_models/test_bedrock.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""Test Anthropic Chat API wrapper."""
-
-from typing import List
-from unittest.mock import MagicMock
-
-import pytest
-from langchain_core.messages import (
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-from langchain_community.chat_models import BedrockChat
-from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
-
-
-@pytest.mark.parametrize(
-    ("messages", "expected"),
-    [
-        ([HumanMessage(content="Hello")], "[INST] Hello [/INST]"),
-        (
-            [HumanMessage(content="Hello"), AIMessage(content="Answer:")],
-            "[INST] Hello [/INST]\nAnswer:",
-        ),
-        (
-            [
-                SystemMessage(content="You're an assistant"),
-                HumanMessage(content="Hello"),
-                AIMessage(content="Answer:"),
-            ],
-            "<<SYS>> You're an assistant <</SYS>>\n[INST] Hello [/INST]\nAnswer:",
-        ),
-    ],
-)
-def test_formatting(messages: List[BaseMessage], expected: str) -> None:
-    result = convert_messages_to_prompt_llama(messages)
-    assert result == expected
-
-
-@pytest.mark.parametrize(
-    "model_id",
-    ["anthropic.claude-v2", "amazon.titan-text-express-v1"],
-)
-def test_different_models_bedrock(model_id: str) -> None:
-    provider = model_id.split(".")[0]
-    client = MagicMock()
-    respbody = MagicMock()
-    if provider == "anthropic":
-        respbody.read.return_value = MagicMock(
-            decode=MagicMock(return_value=b'{"completion":"Hi back"}'),
-        )
-        client.invoke_model.return_value = {"body": respbody}
-    elif provider == "amazon":
-        respbody.read.return_value = '{"results": [{"outputText": "Hi back"}]}'
-        client.invoke_model.return_value = {"body": respbody}
-
-    model = BedrockChat(model_id=model_id, client=client)
-
-    # should not throw an error
-    model.invoke("hello there")
-
-
-def test_bedrock_combine_llm_output() -> None:
-    model_id = "anthropic.claude-3-haiku-20240307-v1:0"
-    client = MagicMock()
-    llm_outputs = [
-        {
-            "model_id": "anthropic.claude-3-haiku-20240307-v1:0",
-            "usage": {
-                "completion_tokens": 1,
-                "prompt_tokens": 2,
-                "total_tokens": 3,
-            },
-        },
-        {
-            "model_id": "anthropic.claude-3-haiku-20240307-v1:0",
-            "usage": {
-                "completion_tokens": 1,
-                "prompt_tokens": 2,
-                "total_tokens": 3,
-            },
-        },
-    ]
-    model = BedrockChat(model_id=model_id, client=client)
-    final_output = model._combine_llm_outputs(llm_outputs)  # type: ignore[arg-type]
-    assert final_output["model_id"] == model_id
-    assert final_output["usage"]["completion_tokens"] == 2
-    assert final_output["usage"]["prompt_tokens"] == 4
-    assert final_output["usage"]["total_tokens"] == 6
diff --git a/libs/community/tests/unit_tests/chat_models/test_cloudflare_workersai.py b/libs/community/tests/unit_tests/chat_models/test_cloudflare_workersai.py
deleted file mode 100644
index 0f0fabe3c..000000000
--- a/libs/community/tests/unit_tests/chat_models/test_cloudflare_workersai.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""Test CloudflareWorkersAI
Chat API wrapper.""" - -from typing import Any, Dict, List, Type - -import pytest -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import ( - AIMessage, - BaseMessage, - HumanMessage, - SystemMessage, - ToolMessage, -) -from langchain_tests.unit_tests import ChatModelUnitTests - -from langchain_community.chat_models.cloudflare_workersai import ( - ChatCloudflareWorkersAI, - _convert_messages_to_cloudflare_messages, -) - - -class TestChatCloudflareWorkersAI(ChatModelUnitTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatCloudflareWorkersAI - - @property - def chat_model_params(self) -> dict: - return { - "account_id": "my_account_id", - "api_token": "my_api_token", - "model": "@hf/nousresearch/hermes-2-pro-mistral-7b", - } - - -@pytest.mark.parametrize( - ("messages", "expected"), - [ - # Test case with a single HumanMessage - ( - [HumanMessage(content="Hello, AI!")], - [{"role": "user", "content": "Hello, AI!"}], - ), - # Test case with SystemMessage, HumanMessage, and AIMessage without tool calls - ( - [ - SystemMessage(content="System initialized."), - HumanMessage(content="Hello, AI!"), - AIMessage(content="Response from AI"), - ], - [ - {"role": "system", "content": "System initialized."}, - {"role": "user", "content": "Hello, AI!"}, - {"role": "assistant", "content": "Response from AI"}, - ], - ), - # Test case with ToolMessage and tool_call_id - ( - [ - ToolMessage( - content="Tool message content", tool_call_id="tool_call_123" - ), - ], - [ - { - "role": "tool", - "content": "Tool message content", - "tool_call_id": "tool_call_123", - } - ], - ), - ], -) -def test_convert_messages_to_cloudflare_format( - messages: List[BaseMessage], expected: List[Dict[str, Any]] -) -> None: - result = _convert_messages_to_cloudflare_messages(messages) - assert result == expected diff --git a/libs/community/tests/unit_tests/chat_models/test_fireworks.py b/libs/community/tests/unit_tests/chat_models/test_fireworks.py deleted file mode 100644 index 61548fa52..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_fireworks.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Test Fireworks chat model""" - -import sys - -import pytest -from pydantic import SecretStr -from pytest import CaptureFixture - -from langchain_community.chat_models import ChatFireworks - -if sys.version_info < (3, 9): - pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True) - - -@pytest.mark.requires("fireworks") -def test_api_key_is_string() -> None: - llm = ChatFireworks(fireworks_api_key="secret-api-key") # type: ignore[arg-type] - assert isinstance(llm.fireworks_api_key, SecretStr) - - -@pytest.mark.requires("fireworks") -def test_api_key_masked_when_passed_via_constructor( - capsys: CaptureFixture, -) -> None: - llm = ChatFireworks(fireworks_api_key="secret-api-key") # type: ignore[arg-type] - print(llm.fireworks_api_key, end="") # noqa: T201 - captured = capsys.readouterr() - - assert captured.out == "**********" diff --git a/libs/community/tests/unit_tests/chat_models/test_huggingface.py b/libs/community/tests/unit_tests/chat_models/test_huggingface.py deleted file mode 100644 index 02b4739f5..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_huggingface.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Test HuggingFace Chat wrapper.""" - -from importlib import import_module - - -def test_import_class() -> None: - """Test that the class can be imported.""" - module_name = "langchain_community.chat_models.huggingface" - class_name 
= "ChatHuggingFace" - - module = import_module(module_name) - assert hasattr(module, class_name) diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index 1c7dff198..5c2bd7440 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -1,30 +1,20 @@ from langchain_community.chat_models import __all__, _module_lookup EXPECTED_ALL = [ - "AzureChatOpenAI", - "BedrockChat", - "ChatAnthropic", "ChatAnyscale", "ChatBaichuan", "ChatClovaX", - "ChatCohere", "ChatCoze", - "ChatDatabricks", "ChatDeepInfra", "ChatEverlyAI", "ChatEdenAI", - "ChatFireworks", "ChatFriendli", "ChatGooglePalm", - "ChatHuggingFace", "ChatHunyuan", "ChatJavelinAIGateway", "ChatKinetica", "ChatKonko", - "ChatLiteLLM", - "ChatLiteLLMRouter", "ChatLlamaCpp", - "ChatMLflowAIGateway", "ChatMaritalk", "ChatMlflow", "ChatMLflowAIGateway", @@ -34,16 +24,10 @@ "ChatOCIModelDeployment", "ChatOCIModelDeploymentVLLM", "ChatOCIModelDeploymentTGI", - "ChatOllama", - "ChatOpenAI", "ChatOutlines", - "ChatPerplexity", "ChatPremAI", - "ChatSambaNovaCloud", - "ChatSambaStudio", "ChatSparkLLM", "ChatTongyi", - "ChatVertexAI", "ChatYandexGPT", "ChatYuan2", "ChatReka", @@ -51,7 +35,6 @@ "ErnieBotChat", "FakeListChatModel", "GPTRouter", - "GigaChat", "HumanInputChatModel", "JinaChat", "LlamaEdgeChatService", @@ -59,7 +42,6 @@ "MoonshotChat", "PaiEasChatEndpoint", "PromptLayerChatOpenAI", - "SolarChat", "QianfanChatEndpoint", "VolcEngineMaasChat", "ChatOctoAI", diff --git a/libs/community/tests/unit_tests/chat_models/test_litellm.py b/libs/community/tests/unit_tests/chat_models/test_litellm.py deleted file mode 100644 index 1d11fe5bd..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_litellm.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Standard LangChain interface tests""" - -from typing import Type - -import pytest -from langchain_core.language_models import BaseChatModel -from langchain_tests.unit_tests import ChatModelUnitTests - -from langchain_community.chat_models.litellm import ChatLiteLLM - - -@pytest.mark.requires("litellm") -class TestLiteLLMStandard(ChatModelUnitTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatLiteLLM - - @property - def chat_model_params(self) -> dict: - return {"api_key": "test_api_key"} - - @pytest.mark.xfail(reason="Not yet implemented.") - def test_standard_params(self, model: BaseChatModel) -> None: - super().test_standard_params(model) diff --git a/libs/community/tests/unit_tests/chat_models/test_octoai.py b/libs/community/tests/unit_tests/chat_models/test_octoai.py index fb639a739..412c724ab 100644 --- a/libs/community/tests/unit_tests/chat_models/test_octoai.py +++ b/libs/community/tests/unit_tests/chat_models/test_octoai.py @@ -9,25 +9,25 @@ @pytest.mark.requires("openai") def test__default_octoai_api_base() -> None: - chat = ChatOctoAI(octoai_api_token=SecretStr("test_token")) # type: ignore[call-arg] + chat = ChatOctoAI(octoai_api_token=SecretStr("test_token")) assert chat.octoai_api_base == DEFAULT_API_BASE @pytest.mark.requires("openai") def test__default_octoai_api_token() -> None: - chat = ChatOctoAI(octoai_api_token=SecretStr("test_token")) # type: ignore[call-arg] + chat = ChatOctoAI(octoai_api_token=SecretStr("test_token")) assert chat.octoai_api_token.get_secret_value() == "test_token" @pytest.mark.requires("openai") def test__default_model_name() -> None: - chat = 
ChatOctoAI(octoai_api_token=SecretStr("test_token")) # type: ignore[call-arg] + chat = ChatOctoAI(octoai_api_token=SecretStr("test_token")) assert chat.model_name == DEFAULT_MODEL @pytest.mark.requires("openai") def test__field_aliases() -> None: - chat = ChatOctoAI(octoai_api_token=SecretStr("test_token"), model="custom-model") # type: ignore[call-arg] + chat = ChatOctoAI(octoai_api_token=SecretStr("test_token"), model="custom-model") assert chat.model_name == "custom-model" assert chat.octoai_api_token.get_secret_value() == "test_token" @@ -41,7 +41,7 @@ def test__missing_octoai_api_token() -> None: @pytest.mark.requires("openai") def test__all_fields_provided() -> None: - chat = ChatOctoAI( # type: ignore[call-arg] + chat = ChatOctoAI( octoai_api_token=SecretStr("test_token"), model="custom-model", octoai_api_base="https://custom.api/base/", diff --git a/libs/community/tests/unit_tests/chat_models/test_ollama.py b/libs/community/tests/unit_tests/chat_models/test_ollama.py deleted file mode 100644 index 96075dda3..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_ollama.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import List, Literal, Optional - -import pytest -from pydantic import BaseModel, ValidationError - -from langchain_community.chat_models import ChatOllama - - -def test_standard_params() -> None: - class ExpectedParams(BaseModel): - ls_provider: str - ls_model_name: str - ls_model_type: Literal["chat", "llm"] - ls_temperature: Optional[float] - ls_max_tokens: Optional[int] = None - ls_stop: Optional[List[str]] = None - - model = ChatOllama(model="llama3") - ls_params = model._get_ls_params() - try: - ExpectedParams(**ls_params) - except ValidationError as e: - pytest.fail(f"Validation error: {e}") - assert ls_params["ls_model_name"] == "llama3" - - # Test optional params - model = ChatOllama(num_predict=10, stop=["test"], temperature=0.33) - ls_params = model._get_ls_params() - try: - ExpectedParams(**ls_params) - except ValidationError as e: - pytest.fail(f"Validation error: {e}") - assert ls_params["ls_max_tokens"] == 10 - assert ls_params["ls_stop"] == ["test"] - assert ls_params["ls_temperature"] == 0.33 diff --git a/libs/community/tests/unit_tests/chat_models/test_openai.py b/libs/community/tests/unit_tests/chat_models/test_openai.py deleted file mode 100644 index f85638daa..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_openai.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Test OpenAI Chat API wrapper.""" - -import json -from typing import Any, List -from unittest.mock import MagicMock, patch - -import pytest -from langchain_core.messages import ( - AIMessage, - FunctionMessage, - HumanMessage, - SystemMessage, -) - -from langchain_community.adapters.openai import convert_dict_to_message -from langchain_community.chat_models.openai import ChatOpenAI - - -@pytest.mark.requires("openai") -def test_openai_model_param() -> None: - test_cases: List[dict] = [ - {"model_name": "foo", "openai_api_key": "foo"}, - {"model": "foo", "openai_api_key": "foo"}, - {"model_name": "foo", "api_key": "foo"}, - {"model_name": "foo", "openai_api_key": "foo", "max_retries": 2}, - ] - - for case in test_cases: - llm = ChatOpenAI(**case) - assert llm.model_name == "foo", "Model name should be 'foo'" - assert llm.openai_api_key == "foo", "API key should be 'foo'" - assert hasattr(llm, "max_retries"), "max_retries attribute should exist" - assert llm.max_retries == 2, "max_retries default should be set to 2" - - -def test_function_message_dict_to_function_message() -> 
None: - content = json.dumps({"result": "Example #1"}) - name = "test_function" - result = convert_dict_to_message( - { - "role": "function", - "name": name, - "content": content, - } - ) - assert isinstance(result, FunctionMessage) - assert result.name == name - assert result.content == content - - -def test__convert_dict_to_message_human() -> None: - message = {"role": "user", "content": "foo"} - result = convert_dict_to_message(message) - expected_output = HumanMessage(content="foo") - assert result == expected_output - - -def test__convert_dict_to_message_ai() -> None: - message = {"role": "assistant", "content": "foo"} - result = convert_dict_to_message(message) - expected_output = AIMessage(content="foo") - assert result == expected_output - - -def test__convert_dict_to_message_system() -> None: - message = {"role": "system", "content": "foo"} - result = convert_dict_to_message(message) - expected_output = SystemMessage(content="foo") - assert result == expected_output - - -@pytest.fixture -def mock_completion() -> dict: - return { - "id": "chatcmpl-7fcZavknQda3SQ", - "object": "chat.completion", - "created": 1689989000, - "model": "gpt-3.5-turbo-0613", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Bar Baz", - }, - "finish_reason": "stop", - } - ], - } - - -@pytest.mark.requires("openai") -def test_openai_predict(mock_completion: dict) -> None: - llm = ChatOpenAI(openai_api_key="foo") # type: ignore[call-arg] - mock_client = MagicMock() - completed = False - - def mock_create(*args: Any, **kwargs: Any) -> Any: - nonlocal completed - completed = True - return mock_completion - - mock_client.create = mock_create - with patch.object( - llm, - "client", - mock_client, - ): - res = llm.invoke("bar") - assert res.content == "Bar Baz" - assert completed - - -@pytest.mark.requires("openai") -async def test_openai_apredict(mock_completion: dict) -> None: - llm = ChatOpenAI(openai_api_key="foo") # type: ignore[call-arg] - mock_client = MagicMock() - completed = False - - async def mock_create(*args: Any, **kwargs: Any) -> Any: - nonlocal completed - completed = True - return mock_completion - - mock_client.create = mock_create - with patch.object( - llm, - "async_client", - mock_client, - ): - res = await llm.ainvoke("bar") - assert res.content == "Bar Baz" - assert completed diff --git a/libs/community/tests/unit_tests/chat_models/test_perplexity.py b/libs/community/tests/unit_tests/chat_models/test_perplexity.py deleted file mode 100644 index ebdb53067..000000000 --- a/libs/community/tests/unit_tests/chat_models/test_perplexity.py +++ /dev/null @@ -1,281 +0,0 @@ -"""Test Perplexity Chat API wrapper.""" - -import os -from typing import Any, Dict, List, Optional, Tuple, Type -from unittest.mock import MagicMock - -import pytest -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessageChunk, BaseMessageChunk -from langchain_tests.unit_tests import ChatModelUnitTests -from pytest_mock import MockerFixture - -from langchain_community.chat_models import ChatPerplexity - -os.environ["PPLX_API_KEY"] = "foo" - - -@pytest.mark.requires("openai") -class TestPerplexityStandard(ChatModelUnitTests): - @property - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatPerplexity - - @property - def init_from_env_params(self) -> Tuple[dict, dict, dict]: - return ( - {"PPLX_API_KEY": "api_key"}, - {}, - {"pplx_api_key": "api_key"}, - ) - - -@pytest.mark.requires("openai") -def test_perplexity_model_name_param() -> 
None: - llm = ChatPerplexity(model="foo") # type: ignore[call-arg] - assert llm.model == "foo" - - -@pytest.mark.requires("openai") -def test_perplexity_model_kwargs() -> None: - llm = ChatPerplexity(model="test", model_kwargs={"foo": "bar"}) # type: ignore[call-arg] - assert llm.model_kwargs == {"foo": "bar"} - - -@pytest.mark.requires("openai") -def test_perplexity_initialization() -> None: - """Test perplexity initialization.""" - # Verify that chat perplexity can be initialized using a secret key provided - # as a parameter rather than an environment variable. - for model in [ - ChatPerplexity( - model="test", timeout=1, api_key="test", temperature=0.7, verbose=True - ), - ChatPerplexity( # type: ignore[call-arg] - model="test", - request_timeout=1, - pplx_api_key="test", - temperature=0.7, - verbose=True, - ), - ]: - assert model.request_timeout == 1 - assert model.pplx_api_key == "test" - - -@pytest.mark.requires("openai") -def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None: - """Test that the stream method includes citations in the additional_kwargs.""" - llm = ChatPerplexity( - model="test", - timeout=30, - verbose=True, - ) - mock_chunk_0 = { - "choices": [ - { - "delta": { - "content": "Hello ", - }, - "finish_reason": None, - } - ], - "citations": ["example.com", "example2.com"], - } - mock_chunk_1 = { - "choices": [ - { - "delta": { - "content": "Perplexity", - }, - "finish_reason": None, - } - ], - "citations": ["example.com", "example2.com"], - } - mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] - mock_stream = MagicMock() - mock_stream.__iter__.return_value = mock_chunks - patcher = mocker.patch.object( - llm.client.chat.completions, "create", return_value=mock_stream - ) - stream = llm.stream("Hello langchain") - full: Optional[BaseMessageChunk] = None - for i, chunk in enumerate(stream): - full = chunk if full is None else full + chunk - if chunk.chunk_position == "last": - continue - assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"] - if i == 0: - assert chunk.additional_kwargs["citations"] == [ - "example.com", - "example2.com", - ] - else: - assert "citations" not in chunk.additional_kwargs - assert isinstance(full, AIMessageChunk) - assert full.content == "Hello Perplexity" - assert full.additional_kwargs == {"citations": ["example.com", "example2.com"]} - - patcher.assert_called_once() - - -@pytest.mark.requires("openai") -def test_perplexity_stream_includes_citations_and_images(mocker: MockerFixture) -> None: - """Test that the stream method includes citations in the additional_kwargs.""" - llm = ChatPerplexity( - model="test", - timeout=30, - verbose=True, - ) - mock_chunk_0 = { - "choices": [ - { - "delta": { - "content": "Hello ", - }, - "finish_reason": None, - } - ], - "citations": ["example.com", "example2.com"], - "images": [ - { - "image_url": "mock_image_url", - "origin_url": "mock_origin_url", - "height": 100, - "width": 100, - } - ], - } - mock_chunk_1 = { - "choices": [ - { - "delta": { - "content": "Perplexity", - }, - "finish_reason": None, - } - ], - "citations": ["example.com", "example2.com"], - "images": [ - { - "image_url": "mock_image_url", - "origin_url": "mock_origin_url", - "height": 100, - "width": 100, - } - ], - } - mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] - mock_stream = MagicMock() - mock_stream.__iter__.return_value = mock_chunks - patcher = mocker.patch.object( - llm.client.chat.completions, "create", return_value=mock_stream - ) - stream = 
llm.stream("Hello langchain") - full: Optional[BaseMessageChunk] = None - for i, chunk in enumerate(stream): - full = chunk if full is None else full + chunk - if chunk.chunk_position == "last": - continue - assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"] - if i == 0: - assert chunk.additional_kwargs["citations"] == [ - "example.com", - "example2.com", - ] - assert chunk.additional_kwargs["images"] == [ - { - "image_url": "mock_image_url", - "origin_url": "mock_origin_url", - "height": 100, - "width": 100, - } - ] - else: - assert "citations" not in chunk.additional_kwargs - assert "images" not in chunk.additional_kwargs - assert isinstance(full, AIMessageChunk) - assert full.content == "Hello Perplexity" - assert full.additional_kwargs == { - "citations": ["example.com", "example2.com"], - "images": [ - { - "image_url": "mock_image_url", - "origin_url": "mock_origin_url", - "height": 100, - "width": 100, - } - ], - } - - patcher.assert_called_once() - - -@pytest.mark.requires("openai") -def test_perplexity_stream_includes_citations_and_related_questions( - mocker: MockerFixture, -) -> None: - """Test that the stream method includes citations in the additional_kwargs.""" - llm = ChatPerplexity( - model="test", - timeout=30, - verbose=True, - ) - mock_chunk_0 = { - "choices": [ - { - "delta": { - "content": "Hello ", - }, - "finish_reason": None, - } - ], - "citations": ["example.com", "example2.com"], - "related_questions": ["example_question_1", "example_question_2"], - } - mock_chunk_1 = { - "choices": [ - { - "delta": { - "content": "Perplexity", - }, - "finish_reason": None, - } - ], - "citations": ["example.com", "example2.com"], - "related_questions": ["example_question_1", "example_question_2"], - } - mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] - mock_stream = MagicMock() - mock_stream.__iter__.return_value = mock_chunks - patcher = mocker.patch.object( - llm.client.chat.completions, "create", return_value=mock_stream - ) - stream = llm.stream("Hello langchain") - full: Optional[BaseMessageChunk] = None - for i, chunk in enumerate(stream): - full = chunk if full is None else full + chunk - if chunk.chunk_position == "last": - continue - assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"] - if i == 0: - assert chunk.additional_kwargs["citations"] == [ - "example.com", - "example2.com", - ] - assert chunk.additional_kwargs["related_questions"] == [ - "example_question_1", - "example_question_2", - ] - else: - assert "citations" not in chunk.additional_kwargs - assert "related_questions" not in chunk.additional_kwargs - assert isinstance(full, AIMessageChunk) - assert full.content == "Hello Perplexity" - assert full.additional_kwargs == { - "citations": ["example.com", "example2.com"], - "related_questions": ["example_question_1", "example_question_2"], - } - - patcher.assert_called_once() diff --git a/libs/community/tests/unit_tests/document_loaders/test_web_base.py b/libs/community/tests/unit_tests/document_loaders/test_web_base.py index 950542303..7c1be50b6 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_web_base.py +++ b/libs/community/tests/unit_tests/document_loaders/test_web_base.py @@ -74,7 +74,7 @@ async def mock_text() -> str: mock_response.text = mock_text mock_get.return_value.__aenter__.return_value = mock_response - loader = WebBaseLoader(web_paths=["https://www.example.com"]) + loader = WebBaseLoader(web_paths=["https://www.example.com"], show_progress=False) results = [] async for 
result in loader.alazy_load(): results.append(result) @@ -100,6 +100,7 @@ async def mock_text_bs4() -> str: loader = WebBaseLoader( web_paths=["https://www.example.com"], bs_kwargs={"parse_only": bs4.SoupStrainer(class_="special-class")}, + show_progress=False, ) results = [] async for result in loader.alazy_load(): @@ -121,6 +122,7 @@ async def mock_text() -> str: loader = WebBaseLoader( web_paths=["https://www.example.com"], header_template={"User-Agent": "test-user-agent"}, + show_progress=False, ) results = loader.aload() assert len(results) == 1 diff --git a/libs/community/tests/unit_tests/embeddings/test_openai.py b/libs/community/tests/unit_tests/embeddings/test_openai.py index 8d349e673..8330d808f 100644 --- a/libs/community/tests/unit_tests/embeddings/test_openai.py +++ b/libs/community/tests/unit_tests/embeddings/test_openai.py @@ -23,16 +23,39 @@ def test_openai_incorrect_field() -> None: @pytest.mark.requires("openai") def test_embed_documents_with_custom_chunk_size() -> None: - embeddings = OpenAIEmbeddings(chunk_size=2) - texts = ["text1", "text2", "text3", "text4"] - custom_chunk_size = 3 + with ( + patch("openai.OpenAI") as mock_openai_class, + patch("tiktoken.encoding_for_model") as mock_tiktoken, + ): + mock_client = mock_openai_class.return_value + mock_embeddings_client = mock_client.embeddings + + # Mock tiktoken encoding + mock_encoding = mock_tiktoken.return_value + mock_encoding.encode.side_effect = [ + [1342, 19], + [1342, 19], + [1342, 19], + [1342, 19], + ] + + embeddings = OpenAIEmbeddings(chunk_size=2) + texts = ["text1", "text2", "text3", "text4"] + custom_chunk_size = 3 - with patch.object(embeddings.client, "create") as mock_create: - mock_create.side_effect = [ + mock_embeddings_client.create.side_effect = [ {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}, {"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]}, ] embeddings.embed_documents(texts, chunk_size=custom_chunk_size) - mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params) - mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params) + + # Verify the expected token inputs - this should be called twice + # with chunk_size=3 + assert mock_embeddings_client.create.call_count == 2 + mock_embeddings_client.create.assert_any_call( + input=[[1342, 19]], **embeddings._invocation_params + ) + mock_embeddings_client.create.assert_any_call( + input=[[1342, 19]], **embeddings._invocation_params + ) diff --git a/libs/community/tests/unit_tests/imports/test_langchain_proxy_imports.py b/libs/community/tests/unit_tests/imports/test_langchain_proxy_imports.py index 6d34cfcd5..21235b8e1 100644 --- a/libs/community/tests/unit_tests/imports/test_langchain_proxy_imports.py +++ b/libs/community/tests/unit_tests/imports/test_langchain_proxy_imports.py @@ -22,5 +22,14 @@ def test_vectorstores() -> None: "MyScaleSettings", "AzureCosmosDBVectorSearch", "Tigris", # Was removed upstream but haven't released yet + "Weaviate", # Removed from vectorstores module + "Qdrant", # Removed from vectorstores module + "Pinecone", # Removed from vectorstores module + "Neo4jVector", # Removed from vectorstores module + "MongoDBAtlasVectorSearch", # Removed from vectorstores module + "Milvus", # Removed from vectorstores module + "MatchingEngine", # Removed from vectorstores module + "DeepLake", # Removed from vectorstores module + "DatabricksVectorSearch", # Removed from vectorstores module ]: assert issubclass(getattr(vectorstores, cls), VectorStore) diff --git 
a/libs/community/tests/unit_tests/retrievers/test_imports.py b/libs/community/tests/unit_tests/retrievers/test_imports.py index dde08e2f8..0e2a04c54 100644 --- a/libs/community/tests/unit_tests/retrievers/test_imports.py +++ b/libs/community/tests/unit_tests/retrievers/test_imports.py @@ -30,7 +30,6 @@ "OutlineRetriever", "PineconeHybridSearchRetriever", "PubMedRetriever", - "QdrantSparseVectorRetriever", "RemoteLangChainRetriever", "RememberizerRetriever", "SVMRetriever", diff --git a/libs/community/tests/unit_tests/test_dependencies.py b/libs/community/tests/unit_tests/test_dependencies.py index 1ce576f81..323488d04 100644 --- a/libs/community/tests/unit_tests/test_dependencies.py +++ b/libs/community/tests/unit_tests/test_dependencies.py @@ -91,7 +91,7 @@ def test_imports() -> None: from langchain_core.prompts import BasePromptTemplate # noqa: F401 from langchain_community.callbacks import OpenAICallbackHandler # noqa: F401 - from langchain_community.chat_models import ChatOpenAI # noqa: F401 + from langchain_community.chat_models.openai import ChatOpenAI # noqa: F401 from langchain_community.document_loaders import BSHTMLLoader # noqa: F401 from langchain_community.embeddings import OpenAIEmbeddings # noqa: F401 from langchain_community.llms import OpenAI # noqa: F401 diff --git a/libs/community/tests/unit_tests/tools/test_imports.py b/libs/community/tests/unit_tests/tools/test_imports.py index e93d224c2..6d45194b9 100644 --- a/libs/community/tests/unit_tests/tools/test_imports.py +++ b/libs/community/tests/unit_tests/tools/test_imports.py @@ -64,9 +64,6 @@ "GmailSendMessage", "GoogleBooksQueryRun", "GoogleCloudTextToSpeechTool", - "GooglePlacesTool", - "GoogleSearchResults", - "GoogleSearchRun", "GoogleSerperResults", "GoogleSerperRun", "HumanInputRun", diff --git a/libs/community/tests/unit_tests/utilities/test_imports.py b/libs/community/tests/unit_tests/utilities/test_imports.py index 7d2a008bb..4d4fa5d7d 100644 --- a/libs/community/tests/unit_tests/utilities/test_imports.py +++ b/libs/community/tests/unit_tests/utilities/test_imports.py @@ -17,9 +17,7 @@ "GoogleFinanceAPIWrapper", "GoogleJobsAPIWrapper", "GoogleLensAPIWrapper", - "GooglePlacesAPIWrapper", "GoogleScholarAPIWrapper", - "GoogleSearchAPIWrapper", "GoogleSerperAPIWrapper", "GoogleTrendsAPIWrapper", "GraphQLAPIWrapper", @@ -36,7 +34,6 @@ "NVIDIARivaTTS", "NVIDIARivaStream", "OpenWeatherMapAPIWrapper", - "OracleSummary", "OutlineAPIWrapper", "NutritionAIAPI", "Portkey", diff --git a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py deleted file mode 100644 index 57fd47411..000000000 --- a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py +++ /dev/null @@ -1,825 +0,0 @@ -import itertools -import random -import uuid -from typing import Dict, List, Optional, Set -from unittest.mock import MagicMock, patch - -import pytest - -from langchain_community.vectorstores import DatabricksVectorSearch -from tests.integration_tests.vectorstores.fake_embeddings import ( - FakeEmbeddings, - fake_texts, -) - -DEFAULT_VECTOR_DIMENSION = 4 - - -class FakeEmbeddingsWithDimension(FakeEmbeddings): - """Fake embeddings functionality for testing.""" - - def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION): - super().__init__() - self.dimension = dimension - - def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]: - """Return simple embeddings.""" - return [ - [float(1.0)] * 
(self.dimension - 1) + [float(i)] - for i in range(len(embedding_texts)) - ] - - def embed_query(self, text: str) -> List[float]: - """Return simple embeddings.""" - return [float(1.0)] * (self.dimension - 1) + [float(0.0)] - - -DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension() -DEFAULT_TEXT_COLUMN = "text" -DEFAULT_VECTOR_COLUMN = "text_vector" -DEFAULT_PRIMARY_KEY = "id" - -DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS = { - "name": "ml.llm.index", - "endpoint_name": "vector_search_endpoint", - "index_type": "DELTA_SYNC", - "primary_key": DEFAULT_PRIMARY_KEY, - "delta_sync_index_spec": { - "source_table": "ml.llm.source_table", - "pipeline_type": "CONTINUOUS", - "embedding_source_columns": [ - { - "name": DEFAULT_TEXT_COLUMN, - "embedding_model_endpoint_name": "openai-text-embedding", - } - ], - }, -} - -DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS = { - "name": "ml.llm.index", - "endpoint_name": "vector_search_endpoint", - "index_type": "DELTA_SYNC", - "primary_key": DEFAULT_PRIMARY_KEY, - "delta_sync_index_spec": { - "source_table": "ml.llm.source_table", - "pipeline_type": "CONTINUOUS", - "embedding_vector_columns": [ - { - "name": DEFAULT_VECTOR_COLUMN, - "embedding_dimension": DEFAULT_VECTOR_DIMENSION, - } - ], - }, -} - -DIRECT_ACCESS_INDEX = { - "name": "ml.llm.index", - "endpoint_name": "vector_search_endpoint", - "index_type": "DIRECT_ACCESS", - "primary_key": DEFAULT_PRIMARY_KEY, - "direct_access_index_spec": { - "embedding_vector_columns": [ - { - "name": DEFAULT_VECTOR_COLUMN, - "embedding_dimension": DEFAULT_VECTOR_DIMENSION, - } - ], - "schema_json": f"{{" - f'"{DEFAULT_PRIMARY_KEY}": "int", ' - f'"feat1": "str", ' - f'"feat2": "float", ' - f'"text": "string", ' - f'"{DEFAULT_VECTOR_COLUMN}": "array"' - f"}}", - }, -} - -ALL_INDEXES = [ - DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, - DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, - DIRECT_ACCESS_INDEX, -] - -EXAMPLE_SEARCH_RESPONSE = { - "manifest": { - "column_count": 3, - "columns": [ - {"name": DEFAULT_PRIMARY_KEY}, - {"name": DEFAULT_TEXT_COLUMN}, - {"name": "score"}, - ], - }, - "result": { - "row_count": len(fake_texts), - "data_array": sorted( - [[str(uuid.uuid4()), s, random.uniform(0, 1)] for s in fake_texts], - key=lambda x: x[2], # type: ignore[arg-type,return-value] - reverse=True, - ), - }, - "next_page_token": "", -} - -EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE: Dict = { - "manifest": { - "column_count": 3, - "columns": [ - {"name": DEFAULT_PRIMARY_KEY}, - {"name": DEFAULT_TEXT_COLUMN}, - {"name": "score"}, - ], - }, - "result": { - "row_count": len(fake_texts), - "data_array": sorted( - [[str(uuid.uuid4()), s, 0.5] for s in fake_texts], - key=lambda x: x[2], - reverse=True, - ), - }, - "next_page_token": "", -} - -EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING = { - "manifest": { - "column_count": 3, - "columns": [ - {"name": DEFAULT_PRIMARY_KEY}, - {"name": DEFAULT_TEXT_COLUMN}, - {"name": DEFAULT_VECTOR_COLUMN}, - {"name": "score"}, - ], - }, - "result": { - "row_count": len(fake_texts), - "data_array": sorted( - [ - [str(uuid.uuid4()), s, e, random.uniform(0, 1)] - for s, e in zip( - fake_texts, DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) - ) - ], - key=lambda x: x[2], # type: ignore[arg-type,return-value] - reverse=True, - ), - }, - "next_page_token": "", -} - -ALL_QUERY_TYPES = [ - None, - "ANN", - "HYBRID", -] - - -def mock_index(index_details: dict) -> MagicMock: - from databricks.vector_search.client import VectorSearchIndex - - index = MagicMock(spec=VectorSearchIndex) - index.describe.return_value = index_details - return 
index - - -def default_databricks_vector_search( - index: MagicMock, columns: Optional[List[str]] = None -) -> DatabricksVectorSearch: - return DatabricksVectorSearch( - index, - embedding=DEFAULT_EMBEDDING_MODEL, - text_column=DEFAULT_TEXT_COLUMN, - columns=columns, - ) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_init_delta_sync_with_managed_embeddings() -> None: - index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) - vectorsearch = DatabricksVectorSearch(index) - assert vectorsearch.index == index - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_init_delta_sync_with_self_managed_embeddings() -> None: - index = mock_index(DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS) - vectorsearch = DatabricksVectorSearch( - index, - embedding=DEFAULT_EMBEDDING_MODEL, - text_column=DEFAULT_TEXT_COLUMN, - ) - assert vectorsearch.index == index - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_init_direct_access_index() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = DatabricksVectorSearch( - index, - embedding=DEFAULT_EMBEDDING_MODEL, - text_column=DEFAULT_TEXT_COLUMN, - ) - assert vectorsearch.index == index - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_init_fail_no_index() -> None: - with pytest.raises(TypeError): - DatabricksVectorSearch() # type: ignore[call-arg] - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_init_fail_index_none() -> None: - with pytest.raises(TypeError) as ex: - DatabricksVectorSearch(None) - assert "index must be of type VectorSearchIndex." in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_init_fail_text_column_mismatch() -> None: - index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) - with pytest.raises(ValueError) as ex: - DatabricksVectorSearch( - index, - text_column="some_other_column", - ) - assert ( - f"text_column 'some_other_column' does not match with the source column of the " - f"index: '{DEFAULT_TEXT_COLUMN}'." in str(ex.value) - ) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] -) -def test_init_fail_no_text_column(index_details: dict) -> None: - index = mock_index(index_details) - with pytest.raises(ValueError) as ex: - DatabricksVectorSearch( - index, - embedding=DEFAULT_EMBEDDING_MODEL, - ) - assert "`text_column` is required for this index." in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize("index_details", [DIRECT_ACCESS_INDEX]) -def test_init_fail_columns_not_in_schema(index_details: dict) -> None: - index = mock_index(index_details) - with pytest.raises(ValueError) as ex: - DatabricksVectorSearch( - index, - embedding=DEFAULT_EMBEDDING_MODEL, - text_column=DEFAULT_TEXT_COLUMN, - columns=["some_random_column"], - ) - assert "column 'some_random_column' is not in the index's schema." in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] -) -def test_init_fail_no_embedding(index_details: dict) -> None: - index = mock_index(index_details) - with pytest.raises(ValueError) as ex: - DatabricksVectorSearch( - index, - text_column=DEFAULT_TEXT_COLUMN, - ) - assert "`embedding` is required for this index." 
in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] -) -def test_init_fail_embedding_dim_mismatch(index_details: dict) -> None: - index = mock_index(index_details) - with pytest.raises(ValueError) as ex: - DatabricksVectorSearch( - index, - text_column=DEFAULT_TEXT_COLUMN, - embedding=FakeEmbeddingsWithDimension(DEFAULT_VECTOR_DIMENSION + 1), - ) - assert ( - f"embedding model's dimension '{DEFAULT_VECTOR_DIMENSION + 1}' does not match " - f"with the index's dimension '{DEFAULT_VECTOR_DIMENSION}'" - ) in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_from_texts_not_supported() -> None: - with pytest.raises(NotImplementedError) as ex: - DatabricksVectorSearch.from_texts(fake_texts, FakeEmbeddings()) - assert ( - "`from_texts` is not supported. " - "Use `add_texts` to add to existing direct-access index." - ) in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", - [DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS], -) -def test_add_texts_not_supported_for_delta_sync_index(index_details: dict) -> None: - index = mock_index(index_details) - vectorsearch = default_databricks_vector_search(index) - with pytest.raises(ValueError) as ex: - vectorsearch.add_texts(fake_texts) - assert "`add_texts` is only supported for direct-access index." in str(ex.value) - - -def is_valid_uuid(val: str) -> bool: - try: - uuid.UUID(str(val)) - return True - except ValueError: - return False - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_add_texts() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = DatabricksVectorSearch( - index, - embedding=DEFAULT_EMBEDDING_MODEL, - text_column=DEFAULT_TEXT_COLUMN, - ) - ids = [idx for idx, i in enumerate(fake_texts)] - vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) - - added_ids = vectorsearch.add_texts(fake_texts, ids=ids) - index.upsert.assert_called_once_with( - [ - { - DEFAULT_PRIMARY_KEY: id_, - DEFAULT_TEXT_COLUMN: text, - DEFAULT_VECTOR_COLUMN: vector, - } - for text, vector, id_ in zip(fake_texts, vectors, ids) - ] - ) - assert len(added_ids) == len(fake_texts) - assert added_ids == ids - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_add_texts_handle_single_text() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = DatabricksVectorSearch( - index, - embedding=DEFAULT_EMBEDDING_MODEL, - text_column=DEFAULT_TEXT_COLUMN, - ) - vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) - - added_ids = vectorsearch.add_texts(fake_texts[0]) - index.upsert.assert_called_once_with( - [ - { - DEFAULT_PRIMARY_KEY: id_, - DEFAULT_TEXT_COLUMN: text, - DEFAULT_VECTOR_COLUMN: vector, - } - for text, vector, id_ in zip(fake_texts, vectors, added_ids) - ] - ) - assert len(added_ids) == 1 - assert is_valid_uuid(added_ids[0]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_add_texts_with_default_id() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = default_databricks_vector_search(index) - vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) - - added_ids = vectorsearch.add_texts(fake_texts) - index.upsert.assert_called_once_with( - [ - { - DEFAULT_PRIMARY_KEY: id_, - DEFAULT_TEXT_COLUMN: text, - DEFAULT_VECTOR_COLUMN: 
vector, - } - for text, vector, id_ in zip(fake_texts, vectors, added_ids) - ] - ) - assert len(added_ids) == len(fake_texts) - assert all([is_valid_uuid(id_) for id_ in added_ids]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_add_texts_with_metadata() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = default_databricks_vector_search(index) - vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) - metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(fake_texts))] - - added_ids = vectorsearch.add_texts(fake_texts, metadatas=metadatas) - index.upsert.assert_called_once_with( - [ - { - DEFAULT_PRIMARY_KEY: id_, - DEFAULT_TEXT_COLUMN: text, - DEFAULT_VECTOR_COLUMN: vector, - **metadata, - } - for text, vector, id_, metadata in zip( - fake_texts, vectors, added_ids, metadatas - ) - ] - ) - assert len(added_ids) == len(fake_texts) - assert all([is_valid_uuid(id_) for id_ in added_ids]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", - [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX], -) -def test_embeddings_property(index_details: dict) -> None: - index = mock_index(index_details) - vectorsearch = default_databricks_vector_search(index) - assert vectorsearch.embeddings == DEFAULT_EMBEDDING_MODEL - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", - [DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS], -) -def test_delete_not_supported_for_delta_sync_index(index_details: dict) -> None: - index = mock_index(index_details) - vectorsearch = default_databricks_vector_search(index) - with pytest.raises(ValueError) as ex: - vectorsearch.delete(["some id"]) - assert "`delete` is only supported for direct-access index." in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_delete() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = default_databricks_vector_search(index) - - vectorsearch.delete(["some id"]) - index.delete.assert_called_once_with(["some id"]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_delete_fail_no_ids() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = default_databricks_vector_search(index) - - with pytest.raises(ValueError) as ex: - vectorsearch.delete() - assert "ids must be provided." 
in str(ex.value) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details, query_type", itertools.product(ALL_INDEXES, [None, "ANN"]) -) -def test_similarity_search(index_details: dict, query_type: Optional[str]) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE - vectorsearch = default_databricks_vector_search(index) - query = "foo" - filters = {"some filter": True} - limit = 7 - - search_result = vectorsearch.similarity_search( - query, k=limit, filter=filters, query_type=query_type - ) - if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS: - index.similarity_search.assert_called_once_with( - columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], - query_text=query, - query_vector=None, - filters=filters, - num_results=limit, - query_type=query_type, - ) - else: - index.similarity_search.assert_called_once_with( - columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], - query_text=None, - query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query), - filters=filters, - num_results=limit, - query_type=query_type, - ) - assert len(search_result) == len(fake_texts) - assert sorted([d.page_content for d in search_result]) == sorted(fake_texts) - assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize("index_details", ALL_INDEXES) -def test_similarity_search_hybrid(index_details: dict) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE - vectorsearch = default_databricks_vector_search(index) - query = "foo" - filters = {"some filter": True} - limit = 7 - - search_result = vectorsearch.similarity_search( - query, k=limit, filter=filters, query_type="HYBRID" - ) - if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS: - index.similarity_search.assert_called_once_with( - columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], - query_text=query, - query_vector=None, - filters=filters, - num_results=limit, - query_type="HYBRID", - ) - else: - index.similarity_search.assert_called_once_with( - columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], - query_text=query, - query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query), - filters=filters, - num_results=limit, - query_type="HYBRID", - ) - assert len(search_result) == len(fake_texts) - assert sorted([d.page_content for d in search_result]) == sorted(fake_texts) - assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_similarity_search_both_filter_and_filters_passed() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE - vectorsearch = default_databricks_vector_search(index) - query = "foo" - filter = {"some filter": True} - filters = {"some other filter": False} - - vectorsearch.similarity_search(query, filter=filter, filters=filters) - index.similarity_search.assert_called_once_with( - columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], - query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query), - # `filter` should prevail over `filters` - filters=filter, - num_results=4, - query_text=None, - query_type=None, - ) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details, columns, expected_columns", - [ - (DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, None, {"id"}), - 
(DIRECT_ACCESS_INDEX, None, {"id"}), - ( - DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, - [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN], - {"text_vector", "id"}, - ), - ( - DIRECT_ACCESS_INDEX, - [DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN, DEFAULT_VECTOR_COLUMN], - {"text_vector", "id"}, - ), - ], -) -def test_mmr_search( - index_details: dict, columns: Optional[List[str]], expected_columns: Set[str] -) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING - vectorsearch = default_databricks_vector_search(index, columns) - query = fake_texts[0] - filters = {"some filter": True} - limit = 1 - - search_result = vectorsearch.max_marginal_relevance_search( - query, k=limit, filters=filters - ) - assert [doc.page_content for doc in search_result] == [fake_texts[0]] - assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns] - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] -) -def test_mmr_parameters(index_details: dict) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_WITH_EMBEDDING - query = fake_texts[0] - limit = 1 - fetch_k = 3 - lambda_mult = 0.25 - filters = {"some filter": True} - - with patch( - "langchain_community.vectorstores.databricks_vector_search.maximal_marginal_relevance" - ) as mock_mmr: - mock_mmr.return_value = [2] - retriever = default_databricks_vector_search(index).as_retriever( - search_type="mmr", - search_kwargs={ - "k": limit, - "fetch_k": fetch_k, - "lambda_mult": lambda_mult, - "filter": filters, - }, - ) - search_result = retriever.invoke(query) - - mock_mmr.assert_called_once() - assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult - assert index.similarity_search.call_args[1]["num_results"] == fetch_k - assert index.similarity_search.call_args[1]["filters"] == filters - assert len(search_result) == limit - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details, threshold", itertools.product(ALL_INDEXES, [0.4, 0.5, 0.8]) -) -def test_similarity_score_threshold(index_details: dict, threshold: float) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE - uniform_response_score = EXAMPLE_SEARCH_RESPONSE_FIXED_SCORE["result"][ - "data_array" - ][0][2] - query = fake_texts[0] - limit = len(fake_texts) - - retriever = default_databricks_vector_search(index).as_retriever( - search_type="similarity_score_threshold", - search_kwargs={"k": limit, "score_threshold": threshold}, - ) - search_result = retriever.invoke(query) - if uniform_response_score >= threshold: - assert len(search_result) == len(fake_texts) - else: - assert len(search_result) == 0 - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_standard_params() -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorstore = default_databricks_vector_search(index) - retriever = vectorstore.as_retriever() - ls_params = retriever._get_ls_params() - assert ls_params == { - "ls_retriever_name": "vectorstore", - "ls_vector_store_provider": "DatabricksVectorSearch", - "ls_embedding_provider": "FakeEmbeddingsWithDimension", - } - - index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) - vectorstore = default_databricks_vector_search(index) - retriever = 
vectorstore.as_retriever() - ls_params = retriever._get_ls_params() - assert ls_params == { - "ls_retriever_name": "vectorstore", - "ls_vector_store_provider": "DatabricksVectorSearch", - } - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details, query_type", - itertools.product( - [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX], [None, "ANN"] - ), -) -def test_similarity_search_by_vector( - index_details: dict, query_type: Optional[str] -) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE - vectorsearch = default_databricks_vector_search(index) - query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") - filters = {"some filter": True} - limit = 7 - - search_result = vectorsearch.similarity_search_by_vector( - query_embedding, k=limit, filter=filters, query_type=query_type - ) - index.similarity_search.assert_called_once_with( - columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], - query_vector=query_embedding, - filters=filters, - num_results=limit, - query_type=query_type, - query_text=None, - ) - assert len(search_result) == len(fake_texts) - assert sorted([d.page_content for d in search_result]) == sorted(fake_texts) - assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] -) -def test_similarity_search_by_vector_hybrid(index_details: dict) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE - vectorsearch = default_databricks_vector_search(index) - query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") - filters = {"some filter": True} - limit = 7 - - search_result = vectorsearch.similarity_search_by_vector( - query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo" - ) - index.similarity_search.assert_called_once_with( - columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], - query_vector=query_embedding, - filters=filters, - num_results=limit, - query_type="HYBRID", - query_text="foo", - ) - assert len(search_result) == len(fake_texts) - assert sorted([d.page_content for d in search_result]) == sorted(fake_texts) - assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize("index_details", ALL_INDEXES) -def test_similarity_search_empty_result(index_details: dict) -> None: - index = mock_index(index_details) - index.similarity_search.return_value = { - "manifest": { - "column_count": 3, - "columns": [ - {"name": DEFAULT_PRIMARY_KEY}, - {"name": DEFAULT_TEXT_COLUMN}, - {"name": "score"}, - ], - }, - "result": { - "row_count": 0, - "data_array": [], - }, - "next_page_token": "", - } - vectorsearch = default_databricks_vector_search(index) - - search_result = vectorsearch.similarity_search("foo") - assert len(search_result) == 0 - - -@pytest.mark.requires("databricks", "databricks.vector_search") -def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None: - index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) - index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE - vectorsearch = default_databricks_vector_search(index) - query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") - filters = {"some filter": True} - limit = 7 - - with 
pytest.raises(ValueError) as ex: - vectorsearch.similarity_search_by_vector( - query_embedding, k=limit, filters=filters - ) - assert ( - "`similarity_search_by_vector` is not supported for index with " - "Databricks-managed embeddings." in str(ex.value) - ) - - -@pytest.mark.requires("databricks", "databricks.vector_search") -@pytest.mark.parametrize( - "method", - [ - "similarity_search", - "similarity_search_with_score", - "similarity_search_by_vector", - "similarity_search_by_vector_with_score", - "max_marginal_relevance_search", - "max_marginal_relevance_search_by_vector", - ], -) -def test_filter_arg_alias(method: str) -> None: - index = mock_index(DIRECT_ACCESS_INDEX) - vectorsearch = default_databricks_vector_search(index) - query = "foo" - query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") - filters = {"some filter": True} - limit = 7 - - if "by_vector" in method: - getattr(vectorsearch, method)(query_embedding, k=limit, filters=filters) - else: - getattr(vectorsearch, method)(query, k=limit, filters=filters) - - index_call_args = index.similarity_search.call_args[1] - assert index_call_args["filters"] == filters diff --git a/libs/community/tests/unit_tests/vectorstores/test_hanavector.py b/libs/community/tests/unit_tests/vectorstores/test_hanavector.py deleted file mode 100644 index 6eab86d33..000000000 --- a/libs/community/tests/unit_tests/vectorstores/test_hanavector.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Test HanaVector functionality.""" - -from langchain_community.vectorstores import HanaDB - - -def test_int_sanitation_with_illegal_value() -> None: - """Test sanitization of int with illegal value""" - successful = True - try: - HanaDB._sanitize_int("HUGO") - successful = False - except ValueError: - pass - - assert successful - - -def test_int_sanitation_with_legal_values() -> None: - """Test sanitization of int with legal values""" - assert HanaDB._sanitize_int(42) == 42 - - assert HanaDB._sanitize_int("21") == 21 - - -def test_int_sanitation_with_negative_values() -> None: - """Test sanitization of int with legal values""" - assert HanaDB._sanitize_int(-1) == -1 - - assert HanaDB._sanitize_int("-1") == -1 - - -def test_int_sanitation_with_illegal_negative_value() -> None: - """Test sanitization of int with illegal value""" - successful = True - try: - HanaDB._sanitize_int(-2) - successful = False - except ValueError: - pass - - assert successful - - -def test_parse_float_array_from_string() -> None: - array_as_string = "[0.1, 0.2, 0.3]" - assert HanaDB._parse_float_array_from_string(array_as_string) == [0.1, 0.2, 0.3] diff --git a/libs/community/tests/unit_tests/vectorstores/test_imports.py b/libs/community/tests/unit_tests/vectorstores/test_imports.py index 7ad483af9..bc2920558 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_imports.py +++ b/libs/community/tests/unit_tests/vectorstores/test_imports.py @@ -28,8 +28,6 @@ "ClickhouseSettings", "CouchbaseVectorStore", "DashVector", - "DatabricksVectorSearch", - "DeepLake", "Dingo", "DistanceStrategy", "DocArrayHnswSearch", @@ -42,7 +40,6 @@ "ElasticsearchStore", "Epsilla", "FAISS", - "HanaDB", "Hologres", "InMemoryVectorStore", "InfinispanVS", @@ -55,23 +52,16 @@ "ManticoreSearch", "ManticoreSearchSettings", "Marqo", - "MatchingEngine", "Meilisearch", - "Milvus", "MomentoVectorIndex", - "MongoDBAtlasVectorSearch", "MyScale", "MyScaleSettings", - "Neo4jVector", "NeuralDBClientVectorStore", "NeuralDBVectorStore", "OpenSearchVectorSearch", - "OracleVS", "PGEmbedding", "PGVector", "PathwayVectorClient", - 
"Pinecone", - "Qdrant", "Redis", "Relyt", "Rockset", @@ -93,14 +83,12 @@ "Typesense", "UpstashVectorStore", "USearch", - "VDMS", "Vald", "Vearch", "Vectara", "VectorStore", "VespaStore", "VLite", - "Weaviate", "Yellowbrick", "ZepVectorStore", "ZepCloudVectorStore", diff --git a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py index b7da470ef..cdc456fa5 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py +++ b/libs/community/tests/unit_tests/vectorstores/test_indexing_docs.py @@ -61,25 +61,18 @@ def check_compatibility(vector_store: VectorStore) -> bool: "Chroma", "CouchbaseVectorStore", "DashVector", - "DatabricksVectorSearch", "TiDBVectorStore", - "DeepLake", "Dingo", "DocumentDBVectorSearch", "ElasticVectorSearch", "ElasticsearchStore", "FAISS", - "HanaDB", "InMemoryVectorStore", "LanceDB", - "Milvus", "MomentoVectorIndex", "MyScale", "OpenSearchVectorSearch", - "OracleVS", "PGVector", - "Pinecone", - "Qdrant", "Redis", "Relyt", "Rockset", @@ -95,12 +88,10 @@ def check_compatibility(vector_store: VectorStore) -> bool: "UpstashVectorStore", "EcloudESVectorStore", "Vald", - "VDMS", "Vearch", "Vectara", "VespaStore", "VLite", - "Weaviate", "Yellowbrick", "ZepVectorStore", "ZepCloudVectorStore", diff --git a/libs/community/tests/unit_tests/vectorstores/test_neo4j.py b/libs/community/tests/unit_tests/vectorstores/test_neo4j.py deleted file mode 100644 index 1bc85d1bb..000000000 --- a/libs/community/tests/unit_tests/vectorstores/test_neo4j.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Test Neo4j functionality.""" - -from langchain_community.vectorstores.neo4j_vector import ( - dict_to_yaml_str, - remove_lucene_chars, -) - - -def test_escaping_lucene() -> None: - """Test escaping lucene characters""" - assert remove_lucene_chars("Hello+World") == "Hello World" - assert remove_lucene_chars("Hello World\\") == "Hello World" - assert ( - remove_lucene_chars("It is the end of the world. Take shelter!") - == "It is the end of the world. Take shelter" - ) - assert ( - remove_lucene_chars("It is the end of the world. Take shelter&&") - == "It is the end of the world. Take shelter" - ) - assert ( - remove_lucene_chars("Bill&&Melinda Gates Foundation") - == "Bill Melinda Gates Foundation" - ) - assert ( - remove_lucene_chars("It is the end of the world. Take shelter(&&)") - == "It is the end of the world. Take shelter" - ) - assert ( - remove_lucene_chars("It is the end of the world. Take shelter??") - == "It is the end of the world. Take shelter" - ) - assert ( - remove_lucene_chars("It is the end of the world. Take shelter^") - == "It is the end of the world. Take shelter" - ) - assert ( - remove_lucene_chars("It is the end of the world. Take shelter+") - == "It is the end of the world. Take shelter" - ) - assert ( - remove_lucene_chars("It is the end of the world. Take shelter-") - == "It is the end of the world. Take shelter" - ) - assert ( - remove_lucene_chars("It is the end of the world. Take shelter~") - == "It is the end of the world. 
Take shelter" - ) - - -def test_converting_to_yaml() -> None: - example_dict = { - "name": "John Doe", - "age": 30, - "skills": ["Python", "Data Analysis", "Machine Learning"], - "location": {"city": "Ljubljana", "country": "Slovenia"}, - } - - yaml_str = dict_to_yaml_str(example_dict) - - expected_output = ( - "name: John Doe\nage: 30\nskills:\n- Python\n- " - "Data Analysis\n- Machine Learning\nlocation:\n city: Ljubljana\n" - " country: Slovenia\n" - ) - - assert yaml_str == expected_output diff --git a/libs/community/uv.lock b/libs/community/uv.lock index 90ea02b91..dafb00199 100644 --- a/libs/community/uv.lock +++ b/libs/community/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'", @@ -1001,6 +1001,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" }, { url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" }, { url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" }, { url = "https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" }, { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, @@ -1010,6 +1012,8 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, { url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" }, + { url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" }, + { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" }, { url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" }, { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, @@ -1019,6 +1023,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, + { url = 
"https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" }, + { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" }, { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, { url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, { url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, @@ -1028,6 +1034,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, { url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" }, { url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, { url = 
"https://files.pythonhosted.org/packages/22/5c/85273fd7cc388285632b0498dbbab97596e04b154933dfe0f3e68156c68c/greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0", size = 273586, upload-time = "2025-08-07T13:16:08.004Z" }, { url = "https://files.pythonhosted.org/packages/d1/75/10aeeaa3da9332c2e761e4c50d4c3556c21113ee3f0afa2cf5769946f7a3/greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f", size = 686346, upload-time = "2025-08-07T13:42:59.944Z" }, @@ -1035,6 +1043,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/8b/29aae55436521f1d6f8ff4e12fb676f3400de7fcf27fccd1d4d17fd8fecd/greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1", size = 694659, upload-time = "2025-08-07T13:53:17.759Z" }, { url = "https://files.pythonhosted.org/packages/92/2e/ea25914b1ebfde93b6fc4ff46d6864564fba59024e928bdc7de475affc25/greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735", size = 695355, upload-time = "2025-08-07T13:18:34.517Z" }, { url = "https://files.pythonhosted.org/packages/72/60/fc56c62046ec17f6b0d3060564562c64c862948c9d4bc8aa807cf5bd74f4/greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337", size = 657512, upload-time = "2025-08-07T13:18:33.969Z" }, + { url = "https://files.pythonhosted.org/packages/23/6e/74407aed965a4ab6ddd93a7ded3180b730d281c77b765788419484cdfeef/greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269", size = 1612508, upload-time = "2025-11-04T12:42:23.427Z" }, + { url = "https://files.pythonhosted.org/packages/0d/da/343cd760ab2f92bac1845ca07ee3faea9fe52bee65f7bcb19f16ad7de08b/greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681", size = 1680760, upload-time = "2025-11-04T12:42:25.341Z" }, { url = "https://files.pythonhosted.org/packages/e3/a5/6ddab2b4c112be95601c13428db1d8b6608a8b6039816f2ba09c346c08fc/greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01", size = 303425, upload-time = "2025-08-07T13:32:27.59Z" }, ]