Skip to content

Commit cdf6202

Browse files
gsa9989aayush3011ccurme
authored
cosmosdbnosql: Added Cosmos DB NoSQL Semantic Cache Integration with tests and jupyter notebook (langchain-ai#24424)
* Added Cosmos DB NoSQL Semantic Cache Integration with tests and jupyter notebook --------- Co-authored-by: Aayush Kataria <[email protected]> Co-authored-by: Chester Curme <[email protected]>
1 parent 27a9056 commit cdf6202

File tree

6 files changed

+495
-81
lines changed

6 files changed

+495
-81
lines changed

docs/docs/integrations/llm_caching.ipynb

Lines changed: 149 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
},
1515
{
1616
"cell_type": "code",
17-
"execution_count": 2,
18-
"id": "88486f6f",
17+
"execution_count": null,
18+
"id": "f938e881",
1919
"metadata": {},
2020
"outputs": [],
2121
"source": [
@@ -30,12 +30,12 @@
3030
},
3131
{
3232
"cell_type": "code",
33-
"execution_count": 3,
33+
"execution_count": 2,
3434
"id": "10ad9224",
3535
"metadata": {
3636
"ExecuteTime": {
37-
"end_time": "2024-04-12T02:05:57.319706Z",
38-
"start_time": "2024-04-12T02:05:57.303868Z"
37+
"end_time": "2024-12-06T00:54:06.474593Z",
38+
"start_time": "2024-12-06T00:53:58.727138Z"
3939
}
4040
},
4141
"outputs": [],
@@ -1820,7 +1820,7 @@
18201820
},
18211821
{
18221822
"cell_type": "code",
1823-
"execution_count": 83,
1823+
"execution_count": null,
18241824
"id": "bc1570a2a77b58c8",
18251825
"metadata": {
18261826
"ExecuteTime": {
@@ -1848,12 +1848,155 @@
18481848
"output_type": "execute_result"
18491849
}
18501850
],
1851+
"source": [
1852+
"%%time\n",
1853+
"# The second time it is, so it goes faster\n",
1854+
"llm.invoke(\"Tell me a joke\")"
1855+
]
1856+
},
1857+
{
1858+
"cell_type": "markdown",
1859+
"id": "235ff73bf7143f13",
1860+
"metadata": {},
1861+
"source": [
1862+
"## Azure CosmosDB NoSql Semantic Cache\n",
1863+
"\n",
1864+
"You can use this integrated [vector database](https://learn.microsoft.com/en-us/azure/cosmos-db/vector-database) for caching."
1865+
]
1866+
},
1867+
{
1868+
"cell_type": "code",
1869+
"execution_count": null,
1870+
"id": "41fea5aa7b2153ca",
1871+
"metadata": {
1872+
"ExecuteTime": {
1873+
"end_time": "2024-12-06T00:55:38.648972Z",
1874+
"start_time": "2024-12-06T00:55:38.290541Z"
1875+
}
1876+
},
1877+
"outputs": [],
1878+
"source": [
1879+
"from typing import Any, Dict\n",
1880+
"\n",
1881+
"from azure.cosmos import CosmosClient, PartitionKey\n",
1882+
"from langchain_community.cache import AzureCosmosDBNoSqlSemanticCache\n",
1883+
"from langchain_openai import OpenAIEmbeddings\n",
1884+
"\n",
1885+
"HOST = \"COSMOS_DB_URI\"\n",
1886+
"KEY = \"COSMOS_DB_KEY\"\n",
1887+
"\n",
1888+
"cosmos_client = CosmosClient(HOST, KEY)\n",
1889+
"\n",
1890+
"\n",
1891+
"def get_vector_indexing_policy() -> dict:\n",
1892+
" return {\n",
1893+
" \"indexingMode\": \"consistent\",\n",
1894+
" \"includedPaths\": [{\"path\": \"/*\"}],\n",
1895+
" \"excludedPaths\": [{\"path\": '/\"_etag\"/?'}],\n",
1896+
" \"vectorIndexes\": [{\"path\": \"/embedding\", \"type\": \"diskANN\"}],\n",
1897+
" }\n",
1898+
"\n",
1899+
"\n",
1900+
"def get_vector_embedding_policy() -> dict:\n",
1901+
" return {\n",
1902+
" \"vectorEmbeddings\": [\n",
1903+
" {\n",
1904+
" \"path\": \"/embedding\",\n",
1905+
" \"dataType\": \"float32\",\n",
1906+
" \"dimensions\": 1536,\n",
1907+
" \"distanceFunction\": \"cosine\",\n",
1908+
" }\n",
1909+
" ]\n",
1910+
" }\n",
1911+
"\n",
1912+
"\n",
1913+
"cosmos_container_properties_test = {\"partition_key\": PartitionKey(path=\"/id\")}\n",
1914+
"cosmos_database_properties_test: Dict[str, Any] = {}\n",
1915+
"\n",
1916+
"set_llm_cache(\n",
1917+
" AzureCosmosDBNoSqlSemanticCache(\n",
1918+
" cosmos_client=cosmos_client,\n",
1919+
" embedding=OpenAIEmbeddings(),\n",
1920+
" vector_embedding_policy=get_vector_embedding_policy(),\n",
1921+
" indexing_policy=get_vector_indexing_policy(),\n",
1922+
" cosmos_container_properties=cosmos_container_properties_test,\n",
1923+
" cosmos_database_properties=cosmos_database_properties_test,\n",
1924+
" )\n",
1925+
")"
1926+
]
1927+
},
1928+
{
1929+
"cell_type": "code",
1930+
"execution_count": 6,
1931+
"id": "1e1cd93819921bf6",
1932+
"metadata": {
1933+
"ExecuteTime": {
1934+
"end_time": "2024-12-06T00:55:44.513080Z",
1935+
"start_time": "2024-12-06T00:55:41.353843Z"
1936+
}
1937+
},
1938+
"outputs": [
1939+
{
1940+
"name": "stdout",
1941+
"output_type": "stream",
1942+
"text": [
1943+
"CPU times: user 374 ms, sys: 34.2 ms, total: 408 ms\n",
1944+
"Wall time: 3.15 s\n"
1945+
]
1946+
},
1947+
{
1948+
"data": {
1949+
"text/plain": [
1950+
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
1951+
]
1952+
},
1953+
"execution_count": 6,
1954+
"metadata": {},
1955+
"output_type": "execute_result"
1956+
}
1957+
],
18511958
"source": [
18521959
"%%time\n",
18531960
"# The first time, it is not yet in cache, so it should take longer\n",
18541961
"llm.invoke(\"Tell me a joke\")"
18551962
]
18561963
},
1964+
{
1965+
"cell_type": "code",
1966+
"execution_count": null,
1967+
"id": "576ce24c1244812a",
1968+
"metadata": {
1969+
"ExecuteTime": {
1970+
"end_time": "2024-12-06T00:55:50.925865Z",
1971+
"start_time": "2024-12-06T00:55:50.548520Z"
1972+
}
1973+
},
1974+
"outputs": [
1975+
{
1976+
"name": "stdout",
1977+
"output_type": "stream",
1978+
"text": [
1979+
"CPU times: user 17.7 ms, sys: 2.88 ms, total: 20.6 ms\n",
1980+
"Wall time: 373 ms\n"
1981+
]
1982+
},
1983+
{
1984+
"data": {
1985+
"text/plain": [
1986+
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
1987+
]
1988+
},
1989+
"execution_count": 8,
1990+
"metadata": {},
1991+
"output_type": "execute_result"
1992+
}
1993+
],
1994+
"source": [
1995+
"%%time\n",
1996+
"# The second time it is, so it goes faster\n",
1997+
"llm.invoke(\"Tell me a joke\")"
1998+
]
1999+
},
18572000
{
18582001
"cell_type": "markdown",
18592002
"id": "306ff47b",

libs/community/langchain_community/cache.py

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@
8080
from langchain_community.utilities.astradb import (
8181
_AstraDBCollectionEnvironment,
8282
)
83-
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
83+
from langchain_community.vectorstores import (
84+
AzureCosmosDBNoSqlVectorSearch,
85+
AzureCosmosDBVectorSearch,
86+
)
8487
from langchain_community.vectorstores import (
8588
OpenSearchVectorSearch as OpenSearchVectorStore,
8689
)
@@ -93,6 +96,7 @@
9396
import momento
9497
import pymemcache
9598
from astrapy.db import AstraDB, AsyncAstraDB
99+
from azure.cosmos.cosmos_client import CosmosClient
96100
from cassandra.cluster import Session as CassandraSession
97101

98102

@@ -2103,7 +2107,7 @@ def __init__(
21032107
ef_construction: int = 64,
21042108
ef_search: int = 40,
21052109
score_threshold: Optional[float] = None,
2106-
application_name: str = "LANGCHAIN_CACHING_PYTHON",
2110+
application_name: str = "LangChain-CDBMongoVCore-SemanticCache-Python",
21072111
):
21082112
"""
21092113
Args:
@@ -2268,14 +2272,118 @@ def clear(self, **kwargs: Any) -> None:
22682272
index_name = self._index_name(kwargs["llm_string"])
22692273
if index_name in self._cache_dict:
22702274
self._cache_dict[index_name].get_collection().delete_many({})
2271-
# self._cache_dict[index_name].clear_collection()
22722275

22732276
@staticmethod
22742277
def _validate_enum_value(value: Any, enum_type: Type[Enum]) -> None:
22752278
if not isinstance(value, enum_type):
22762279
raise ValueError(f"Invalid enum value: {value}. Expected {enum_type}.")
22772280

22782281

2282+
class AzureCosmosDBNoSqlSemanticCache(BaseCache):
    """Semantic LLM cache backed by an Azure Cosmos DB NoSQL vector store.

    Each prompt is stored as a text in an ``AzureCosmosDBNoSqlVectorSearch``
    store, with the serialized generations kept in the document metadata under
    ``"return_val"``. A lookup runs a similarity search for the prompt and
    returns the generations of the closest stored entry.

    One vector store client is created per ``llm_string`` and memoized in
    ``_cache_dict``. Every client is built with the same ``database_name`` and
    ``container_name``.
    NOTE(review): this presumably means entries for different LLMs live in the
    same container -- confirm against ``AzureCosmosDBNoSqlVectorSearch``.
    """

    def __init__(
        self,
        embedding: Embeddings,
        cosmos_client: CosmosClient,
        database_name: str = "CosmosNoSqlCacheDB",
        container_name: str = "CosmosNoSqlCacheContainer",
        *,
        vector_embedding_policy: Dict[str, Any],
        indexing_policy: Dict[str, Any],
        cosmos_container_properties: Dict[str, Any],
        cosmos_database_properties: Dict[str, Any],
        create_container: bool = True,
    ):
        """Record the settings used to build vector store clients on demand.

        No vector store is created here; clients are built on first use by
        ``_get_llm_cache``.

        Args:
            embedding: Embedding model handed to the vector store.
            cosmos_client: Cosmos DB client handed to the vector store.
            database_name: Database name handed to the vector store.
            container_name: Container name handed to the vector store.
            vector_embedding_policy: Vector embedding policy handed to the
                vector store.
            indexing_policy: Indexing policy handed to the vector store.
            cosmos_container_properties: Container properties handed to the
                vector store.
            cosmos_database_properties: Database properties handed to the
                vector store.
            create_container: Forwarded unchanged to the vector store.
        """
        self.cosmos_client = cosmos_client
        self.database_name = database_name
        self.container_name = container_name
        self.embedding = embedding
        self.vector_embedding_policy = vector_embedding_policy
        self.indexing_policy = indexing_policy
        self.cosmos_container_properties = cosmos_container_properties
        self.cosmos_database_properties = cosmos_database_properties
        self.create_container = create_container
        # Vector store clients keyed by the name returned from _cache_name().
        self._cache_dict: Dict[str, AzureCosmosDBNoSqlVectorSearch] = {}

    def _cache_name(self, llm_string: str) -> str:
        """Return the ``_cache_dict`` key for ``llm_string``."""
        hashed_index = _hash(llm_string)
        return f"cache:{hashed_index}"

    def _get_llm_cache(self, llm_string: str) -> AzureCosmosDBNoSqlVectorSearch:
        """Return the vector store client for ``llm_string``, creating it once.

        Raises:
            KeyError: If no client is memoized yet and ``self.cosmos_client``
                is falsy, because nothing can be created in that case.
        """
        cache_name = self._cache_name(llm_string)

        # Reuse the client already built for this llm_string.
        if cache_name in self._cache_dict:
            return self._cache_dict[cache_name]

        # First use for this llm_string: build and memoize a new client.
        if self.cosmos_client:
            self._cache_dict[cache_name] = AzureCosmosDBNoSqlVectorSearch(
                cosmos_client=self.cosmos_client,
                embedding=self.embedding,
                vector_embedding_policy=self.vector_embedding_policy,
                indexing_policy=self.indexing_policy,
                cosmos_container_properties=self.cosmos_container_properties,
                cosmos_database_properties=self.cosmos_database_properties,
                database_name=self.database_name,
                container_name=self.container_name,
                create_container=self.create_container,
            )

        return self._cache_dict[cache_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up cached generations for the entry closest to ``prompt``.

        Args:
            prompt: Prompt text used as the similarity search query.
            llm_string: Selects which memoized vector store client is used.

        Returns:
            The deserialized generations, or ``None`` when nothing was found.
        """
        llm_cache = self._get_llm_cache(llm_string)
        generations: List = []
        # Only the single nearest entry is requested (k=1).
        # NOTE(review): no score threshold is applied here, so whatever entry
        # the search returns is treated as a hit -- confirm this is intended.
        results = llm_cache.similarity_search(
            query=prompt,
            k=1,
        )
        if results:
            for document in results:
                try:
                    generations.extend(loads(document.metadata["return_val"]))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )

                    # Fall back to the JSON loader for values that ``loads``
                    # could not deserialize.
                    generations.extend(
                        _load_generations_from_json(document.metadata["return_val"])
                    )
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` in the cache under ``prompt``.

        Args:
            prompt: Prompt text; it is the text added to the vector store.
            llm_string: Selects which memoized vector store client is used.
            return_val: Generations to serialize into the entry metadata.

        Raises:
            ValueError: If any element of ``return_val`` is not a
                ``Generation``.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "CosmosDBNoSqlSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(gen)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": dumps([g for g in return_val]),
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string.

        Does nothing when no vector store client has been created yet for the
        given ``llm_string``. Otherwise every item read from that client's
        container is deleted.

        Args:
            **kwargs: Must contain ``llm_string``.

        Raises:
            KeyError: If ``llm_string`` is not supplied.
        """
        # Same key as the sibling semantic caches; the previous "llm-string"
        # key could not be passed with normal keyword syntax.
        cache_name = self._cache_name(llm_string=kwargs["llm_string"])
        if cache_name not in self._cache_dict:
            return

        # Index with the computed name, not the literal string "cache_name".
        container = self._cache_dict[cache_name].get_container()
        for item in container.read_all_items():
            # delete_item needs the item's partition key value.
            # NOTE(review): assumes the container is partitioned on "/id", as
            # in the documented setup -- confirm for other partition key paths.
            container.delete_item(item, partition_key=item["id"])
2385+
2386+
22792387
class OpenSearchSemanticCache(BaseCache):
22802388
"""Cache that uses OpenSearch vector store backend"""
22812389

libs/community/langchain_community/vectorstores/azure_cosmos_db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
index_name: str = "vectorSearchIndex",
8383
text_key: str = "textContent",
8484
embedding_key: str = "vectorContent",
85-
application_name: str = "LANGCHAIN_PYTHON",
85+
application_name: str = "LangChain-CDBMongoVCore-VectorStore-Python",
8686
):
8787
"""Constructor for AzureCosmosDBVectorSearch
8888
@@ -121,7 +121,7 @@ def from_connection_string(
121121
connection_string: str,
122122
namespace: str,
123123
embedding: Embeddings,
124-
application_name: str = "LANGCHAIN_PYTHON",
124+
application_name: str = "LangChain-CDBMongoVCore-VectorStore-Python",
125125
**kwargs: Any,
126126
) -> AzureCosmosDBVectorSearch:
127127
"""Creates an Instance of AzureCosmosDBVectorSearch

libs/community/langchain_community/vectorstores/azure_cosmos_db_no_sql.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from langchain_community.vectorstores.utils import maximal_marginal_relevance
1515

1616
if TYPE_CHECKING:
17-
from azure.cosmos import CosmosClient
17+
from azure.cosmos import ContainerProxy, CosmosClient
1818
from azure.identity import DefaultAzureCredential
1919

2020
USER_AGENT = ("LangChain-CDBNoSql-VectorStore-Python",)
@@ -859,3 +859,6 @@ def _where_clause_operator_map(self) -> Dict[str, str]:
859859
"$full_text_contains_any": "FullTextContainsAny",
860860
}
861861
return operator_map
862+
863+
def get_container(self) -> ContainerProxy:
    """Return the container proxy stored on this instance as ``_container``."""
    return self._container

0 commit comments

Comments
 (0)