Commit 54d7f1c: fix caching (#658)

Parent: d0fdc6d

File tree: 2 files changed (+43, -2 lines)

  langchain/cache.py
  tests/unit_tests/llms/test_base.py

langchain/cache.py
Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
         """Initialize by creating all tables."""
         self.engine = engine
         self.cache_schema = cache_schema
-        Base.metadata.create_all(self.engine)
+        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""

tests/unit_tests/llms/test_base.py
Lines changed: 42 additions & 1 deletion

@@ -1,6 +1,9 @@
 """Test base LLM functionality."""
+from sqlalchemy import Column, Integer, Sequence, String, create_engine
+from sqlalchemy.ext.declarative import declarative_base
+
 import langchain
-from langchain.cache import InMemoryCache
+from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.schema import Generation, LLMResult
 from tests.unit_tests.llms.fake_llm import FakeLLM

@@ -28,3 +31,41 @@ def test_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
+
+
+def test_custom_caching() -> None:
+    """Test custom_caching behavior."""
+    Base = declarative_base()
+
+    class FulltextLLMCache(Base):  # type: ignore
+        """Postgres table for fulltext-indexed LLM Cache."""
+
+        __tablename__ = "llm_cache_fulltext"
+        id = Column(Integer, Sequence("cache_id"), primary_key=True)
+        prompt = Column(String, nullable=False)
+        llm = Column(String, nullable=False)
+        idx = Column(Integer)
+        response = Column(String)
+
+    engine = create_engine("sqlite://")
+    langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
+    llm = FakeLLM()
+    params = llm._llm_dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["foo", "bar", "foo"])
+    expected_cache_output = [Generation(text="foo")]
+    cache_output = langchain.llm_cache.lookup("bar", llm_string)
+    assert cache_output == expected_cache_output
+    langchain.llm_cache = None
+    expected_generations = [
+        [Generation(text="fizz")],
+        [Generation(text="foo")],
+        [Generation(text="fizz")],
+    ]
+    expected_output = LLMResult(
+        expected_generations,
+        llm_output=None,
+    )
+    assert output == expected_output
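How the new test's assertions fit together: the pre-seeded prompt "foo" is served from the cache as "fizz" without calling the model, while the unseeded "bar" falls through to FakeLLM, and generate() writes that result back to the cache. That round trip is why lookup("bar", llm_string) returns [Generation(text="foo")] after generation, and why the final LLMResult interleaves "fizz", "foo", "fizz".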
