Commit 9f2ba5b

feat: add exact match caching (#1717)
This PR introduces a robust caching abstraction to improve performance and prevent progress loss during API calls (e.g., LLM or embedding requests). This was brought up in #1522 and later added as an enhancement request in #1602.
1 parent cf06513 · commit 9f2ba5b
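
For context, a minimal usage sketch of the API this commit adds (these names are also re-exported from the top-level ragas package, per the __init__.py diff below; expensive_call is a hypothetical stand-in for an LLM or embedding request):

    from ragas.cache import DiskCacheBackend, cacher

    cache = DiskCacheBackend(cache_dir=".cache")  # on-disk cache via diskcache

    @cacher(cache_backend=cache)
    def expensive_call(prompt: str) -> str:
        # Hypothetical stand-in for a real API call; only runs on a cache miss.
        print("cache miss -- recomputing")
        return prompt.upper()

    expensive_call("hello")  # computed and written to .cache/
    expensive_call("hello")  # exact same arguments, served from disk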

File tree: 9 files changed (+290, -3 lines)

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -171,3 +171,4 @@ src/ragas/_version.py
 .vscode
 .envrc
 uv.lock
+.cache/

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ dependencies = [
     "pydantic>=2",
     "openai>1",
     "pysbd>=0.3.4",
+    "diskcache>=5.6.3",
 ]
 dynamic = ["version", "readme"]

requirements/test.txt

Lines changed: 1 addition & 0 deletions

@@ -3,3 +3,4 @@ pytest-xdist[psutil]
 pytest-asyncio
 llama_index
 nbmake
+diskcache

src/ragas/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -1,3 +1,4 @@
+from ragas.cache import CacheInterface, DiskCacheBackend, cacher
 from ragas.dataset_schema import EvaluationDataset, MultiTurnSample, SingleTurnSample
 from ragas.evaluation import evaluate
 from ragas.run_config import RunConfig
@@ -15,4 +16,7 @@
     "SingleTurnSample",
     "MultiTurnSample",
     "EvaluationDataset",
+    "cacher",
+    "CacheInterface",
+    "DiskCacheBackend",
 ]

src/ragas/cache.py

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+import functools
+import hashlib
+import inspect
+import json
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+
+from pydantic import BaseModel
+
+
+class CacheInterface(ABC):
+    @abstractmethod
+    def get(self, key: str) -> Any:
+        pass
+
+    @abstractmethod
+    def set(self, key: str, value) -> None:
+        pass
+
+    @abstractmethod
+    def has_key(self, key: str) -> bool:
+        pass
+
+
+class DiskCacheBackend(CacheInterface):
+    def __init__(self, cache_dir: str = ".cache"):
+        try:
+            from diskcache import Cache
+        except ImportError:
+            raise ImportError(
+                "For using the diskcache backend, please install it with `pip install diskcache`."
+            )
+
+        self.cache = Cache(cache_dir)
+
+    def get(self, key: str) -> Any:
+        return self.cache.get(key)
+
+    def set(self, key: str, value) -> None:
+        self.cache.set(key, value)
+
+    def has_key(self, key: str) -> bool:
+        return key in self.cache
+
+    def __del__(self):
+        if hasattr(self, "cache"):
+            self.cache.close()
+
+
+def _make_hashable(o):
+    if isinstance(o, (tuple, list)):
+        return tuple(_make_hashable(e) for e in o)
+    elif isinstance(o, dict):
+        return tuple(sorted((k, _make_hashable(v)) for k, v in o.items()))
+    elif isinstance(o, set):
+        return tuple(sorted(_make_hashable(e) for e in o))
+    elif isinstance(o, BaseModel):
+        return _make_hashable(o.model_dump())
+    else:
+        return o
+
+
+EXCLUDE_KEYS = ["callbacks"]
+
+
+def _generate_cache_key(func, args, kwargs):
+    if inspect.ismethod(func):
+        args = args[1:]
+
+    filtered_kwargs = {k: v for k, v in kwargs.items() if k not in EXCLUDE_KEYS}
+
+    key_data = {
+        "function": func.__qualname__,
+        "args": _make_hashable(args),
+        "kwargs": _make_hashable(filtered_kwargs),
+    }
+
+    key_string = json.dumps(key_data, sort_keys=True, default=str)
+    cache_key = hashlib.sha256(key_string.encode("utf-8")).hexdigest()
+    return cache_key
+
+
+def cacher(cache_backend: Optional[CacheInterface] = None):
+    def decorator(func):
+        if cache_backend is None:
+            return func
+
+        # hack to make pyright happy
+        backend: CacheInterface = cache_backend
+
+        is_async = inspect.iscoroutinefunction(func)
+
+        @functools.wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            cache_key = _generate_cache_key(func, args, kwargs)
+
+            if backend.has_key(cache_key):
+                return backend.get(cache_key)
+
+            result = await func(*args, **kwargs)
+            backend.set(cache_key, result)
+            return result
+
+        @functools.wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            cache_key = _generate_cache_key(func, args, kwargs)
+
+            if backend.has_key(cache_key):
+                return backend.get(cache_key)
+
+            result = func(*args, **kwargs)
+            backend.set(cache_key, result)
+            return result
+
+        return async_wrapper if is_async else sync_wrapper
+
+    return decorator
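
Any object implementing CacheInterface can be passed to cacher. As a sketch, a hypothetical in-memory backend (not part of this commit, but handy for tests) could look like:

    from typing import Any

    from ragas.cache import CacheInterface

    class InMemoryCacheBackend(CacheInterface):
        # Hypothetical dict-backed cache; not part of this commit.
        def __init__(self):
            self._store: dict = {}

        def get(self, key: str) -> Any:
            return self._store.get(key)

        def set(self, key: str, value) -> None:
            self._store[key] = value

        def has_key(self, key: str) -> bool:
            return key in self._store

Note that _generate_cache_key serializes the function's qualified name together with its normalized args and kwargs, then hashes the result with SHA-256, so only exact argument matches hit the cache (hence "exact match caching"); kwargs listed in EXCLUDE_KEYS (currently just callbacks) are left out of the key.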

src/ragas/embeddings/base.py

Lines changed: 37 additions & 3 deletions

@@ -2,7 +2,7 @@
 
 import asyncio
 import typing as t
-from abc import ABC
+from abc import ABC, abstractmethod
 from dataclasses import field
 from typing import List
 
@@ -12,6 +12,7 @@
 from pydantic.dataclasses import dataclass
 from pydantic_core import CoreSchema, core_schema
 
+from ragas.cache import CacheInterface, cacher
 from ragas.run_config import RunConfig, add_async_retry, add_retry
 
 if t.TYPE_CHECKING:
@@ -35,6 +36,20 @@ class BaseRagasEmbeddings(Embeddings, ABC):
     """
 
     run_config: RunConfig
+    cache: t.Optional[CacheInterface] = None
+
+    def __init__(self, cache: t.Optional[CacheInterface] = None):
+        super().__init__()
+        self.cache = cache
+        if self.cache is not None:
+            self.embed_query = cacher(cache_backend=self.cache)(self.embed_query)
+            self.embed_documents = cacher(cache_backend=self.cache)(
+                self.embed_documents
+            )
+            self.aembed_query = cacher(cache_backend=self.cache)(self.aembed_query)
+            self.aembed_documents = cacher(cache_backend=self.cache)(
+                self.aembed_documents
+            )
 
     async def embed_text(self, text: str, is_async=True) -> List[float]:
         """
@@ -61,6 +76,12 @@ async def embed_texts(
         )
         return await loop.run_in_executor(None, embed_documents_with_retry, texts)
 
+    @abstractmethod
+    async def aembed_query(self, text: str) -> List[float]: ...
+
+    @abstractmethod
+    async def aembed_documents(self, texts: List[str]) -> t.List[t.List[float]]: ...
+
     def set_run_config(self, run_config: RunConfig):
         """
         Set the run configuration for the embedding operations.
@@ -85,8 +106,12 @@ class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
     """
 
     def __init__(
-        self, embeddings: Embeddings, run_config: t.Optional[RunConfig] = None
+        self,
+        embeddings: Embeddings,
+        run_config: t.Optional[RunConfig] = None,
+        cache: t.Optional[CacheInterface] = None,
     ):
+        super().__init__(cache=cache)
         self.embeddings = embeddings
         if run_config is None:
             run_config = RunConfig()
@@ -189,11 +214,13 @@ class HuggingfaceEmbeddings(BaseRagasEmbeddings):
     cache_folder: t.Optional[str] = None
     model_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
     encode_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
+    cache: t.Optional[CacheInterface] = None
 
     def __post_init__(self):
         """
         Initialize the model after the object is created.
         """
+        super().__init__(cache=self.cache)
         try:
             import sentence_transformers
             from transformers import AutoConfig
@@ -226,6 +253,9 @@ def __post_init__(self):
         if "convert_to_tensor" not in self.encode_kwargs:
             self.encode_kwargs["convert_to_tensor"] = True
 
+        if self.cache is not None:
+            self.predict = cacher(cache_backend=self.cache)(self.predict)
+
     def embed_query(self, text: str) -> List[float]:
         """
         Embed a single query text.
@@ -297,8 +327,12 @@ class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings):
     """
 
     def __init__(
-        self, embeddings: BaseEmbedding, run_config: t.Optional[RunConfig] = None
+        self,
+        embeddings: BaseEmbedding,
+        run_config: t.Optional[RunConfig] = None,
+        cache: t.Optional[CacheInterface] = None,
    ):
+        super().__init__(cache=cache)
         self.embeddings = embeddings
         if run_config is None:
             run_config = RunConfig()
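
A sketch of the new cache parameter in use (the OpenAIEmbeddings import is an assumption for illustration, any LangChain Embeddings implementation should work, and LangchainEmbeddingsWrapper is assumed to be exported from ragas.embeddings):

    from langchain_openai import OpenAIEmbeddings  # assumed installed, for illustration

    from ragas.cache import DiskCacheBackend
    from ragas.embeddings import LangchainEmbeddingsWrapper

    embedder = LangchainEmbeddingsWrapper(
        embeddings=OpenAIEmbeddings(),
        cache=DiskCacheBackend(cache_dir=".cache"),
    )
    # BaseRagasEmbeddings.__init__ wraps embed_query/embed_documents and their
    # async variants with cacher(), so repeated inputs are served from disk.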

src/ragas/llms/base.py

Lines changed: 12 additions & 0 deletions

@@ -13,6 +13,7 @@
 from langchain_openai.llms import AzureOpenAI, OpenAI
 from langchain_openai.llms.base import BaseOpenAI
 
+from ragas.cache import CacheInterface, cacher
 from ragas.exceptions import LLMDidNotFinishException
 from ragas.integrations.helicone import helicone_config
 from ragas.run_config import RunConfig, add_async_retry
@@ -47,6 +48,13 @@ def is_multiple_completion_supported(llm: BaseLanguageModel) -> bool:
 class BaseRagasLLM(ABC):
     run_config: RunConfig = field(default_factory=RunConfig, repr=False)
     multiple_completion_supported: bool = field(default=False, repr=False)
+    cache: t.Optional[CacheInterface] = field(default=None, repr=False)
+
+    def __post_init__(self):
+        # If a cache_backend is provided, wrap the implementation methods at construction time.
+        if self.cache is not None:
+            self.generate_text = cacher(cache_backend=self.cache)(self.generate_text)
+            self.agenerate_text = cacher(cache_backend=self.cache)(self.agenerate_text)
 
     def set_run_config(self, run_config: RunConfig):
         self.run_config = run_config
@@ -124,7 +132,9 @@ def __init__(
         langchain_llm: BaseLanguageModel,
         run_config: t.Optional[RunConfig] = None,
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
+        cache: t.Optional[CacheInterface] = None,
     ):
+        super().__init__(cache=cache)
         self.langchain_llm = langchain_llm
         if run_config is None:
             run_config = RunConfig()
@@ -273,7 +283,9 @@ def __init__(
         self,
         llm: BaseLLM,
         run_config: t.Optional[RunConfig] = None,
+        cache: t.Optional[CacheInterface] = None,
     ):
+        super().__init__(cache=cache)
         self.llm = llm
 
         try:
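
The same pattern applies to the LLM wrappers. A sketch, assuming the first wrapper shown above is LangchainLLMWrapper (the class name is not visible in this diff) and that langchain_openai is installed:

    from langchain_openai import ChatOpenAI  # assumed installed, for illustration

    from ragas.cache import DiskCacheBackend
    from ragas.llms import LangchainLLMWrapper  # assumed export

    llm = LangchainLLMWrapper(
        langchain_llm=ChatOpenAI(model="gpt-4o-mini"),
        cache=DiskCacheBackend(),
    )
    # BaseRagasLLM.__post_init__ wraps generate_text/agenerate_text with cacher(),
    # so identical generation requests are answered from the on-disk cache.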

src/ragas/prompt/pydantic_prompt.py

Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ async def generate_multiple(
             If there's an error parsing the output.
         """
         callbacks = callbacks or []
+
         processed_data = self.process_input(data)
         prompt_rm, prompt_cb = new_group(
             name=self.name,
