Skip to content

Commit d10fd02

Browse files
authored
langchain[patch]: Allow specifying other hashing functions in embeddings (#31561)
Allow specifying other hashing functions in embeddings
1 parent 4071670 commit d10fd02

File tree

2 files changed

+190
-14
lines changed

2 files changed

+190
-14
lines changed

libs/langchain/langchain/embeddings/cache.py

Lines changed: 100 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import hashlib
1313
import json
1414
import uuid
15+
import warnings
1516
from collections.abc import Sequence
16-
from functools import partial
17-
from typing import Callable, Optional, Union, cast
17+
from typing import Callable, Literal, Optional, Union, cast
1818

1919
from langchain_core.embeddings import Embeddings
2020
from langchain_core.stores import BaseStore, ByteStore
@@ -25,20 +25,51 @@
2525
NAMESPACE_UUID = uuid.UUID(int=1985)
2626

2727

28-
def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
29-
"""Hash a string and returns the corresponding UUID."""
30-
hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
31-
return uuid.uuid5(NAMESPACE_UUID, hash_value)
28+
def _sha1_hash_to_uuid(text: str) -> uuid.UUID:
29+
"""Return a UUID derived from *text* using SHA‑1 (deterministic).
3230
31+
Deterministic and fast, **but not collision‑resistant**.
3332
34-
def _key_encoder(key: str, namespace: str) -> str:
35-
"""Encode a key."""
36-
return namespace + str(_hash_string_to_uuid(key))
33+
A malicious attacker could try to create two different texts that hash to the same
34+
UUID. This may not necessarily be an issue in the context of caching embeddings,
35+
but new applications should swap this out for a stronger hash function like
36+
xxHash, BLAKE2 or SHA‑256, which are collision-resistant.
37+
"""
38+
sha1_hex = hashlib.sha1(text.encode("utf-8")).hexdigest()
39+
# Embed the hex string in `uuid5` to obtain a valid UUID.
40+
return uuid.uuid5(NAMESPACE_UUID, sha1_hex)
41+
42+
43+
def _make_default_key_encoder(namespace: str, algorithm: str) -> Callable[[str], str]:
44+
"""Create a default key encoder function.
45+
46+
Args:
47+
namespace: Prefix that segregates keys from different embedding models.
48+
algorithm:
49+
* `sha1` - fast but not collision‑resistant
50+
* `blake2b` - cryptographically strong, faster than SHA‑1
51+
* `sha256` - cryptographically strong, slower than SHA‑1
52+
* `sha512` - cryptographically strong, slower than SHA‑1
3753
54+
Returns:
55+
A function that encodes a key using the specified algorithm.
56+
"""
57+
if algorithm == "sha1":
58+
_warn_about_sha1_encoder()
59+
60+
def _key_encoder(key: str) -> str:
61+
"""Encode a key using the specified algorithm."""
62+
if algorithm == "sha1":
63+
return f"{namespace}{_sha1_hash_to_uuid(key)}"
64+
if algorithm == "blake2b":
65+
return f"{namespace}{hashlib.blake2b(key.encode('utf-8')).hexdigest()}"
66+
if algorithm == "sha256":
67+
return f"{namespace}{hashlib.sha256(key.encode('utf-8')).hexdigest()}"
68+
if algorithm == "sha512":
69+
return f"{namespace}{hashlib.sha512(key.encode('utf-8')).hexdigest()}"
70+
raise ValueError(f"Unsupported algorithm: {algorithm}")
3871

39-
def _create_key_encoder(namespace: str) -> Callable[[str], str]:
40-
"""Create an encoder for a key."""
41-
return partial(_key_encoder, namespace=namespace)
72+
return _key_encoder
4273

4374

4475
def _value_serializer(value: Sequence[float]) -> bytes:
@@ -51,6 +82,28 @@ def _value_deserializer(serialized_value: bytes) -> list[float]:
5182
return cast(list[float], json.loads(serialized_value.decode()))
5283

5384

85+
# The warning is global; track emission, so it appears only once.
86+
_warned_about_sha1: bool = False
87+
88+
89+
def _warn_about_sha1_encoder() -> None:
90+
"""Emit a one‑time warning about SHA‑1 collision weaknesses."""
91+
global _warned_about_sha1
92+
if not _warned_about_sha1:
93+
warnings.warn(
94+
"Using default key encoder: SHA‑1 is *not* collision‑resistant. "
95+
"While acceptable for most cache scenarios, a motivated attacker "
96+
"can craft two different payloads that map to the same cache key. "
97+
"If that risk matters in your environment, supply a stronger "
98+
"encoder (e.g. SHA‑256 or BLAKE2) via the `key_encoder` argument. "
99+
"If you change the key encoder, consider also creating a new cache, "
100+
"to avoid (the potential for) collisions with existing keys.",
101+
category=UserWarning,
102+
stacklevel=2,
103+
)
104+
_warned_about_sha1 = True
105+
106+
54107
class CacheBackedEmbeddings(Embeddings):
55108
"""Interface for caching results from embedding models.
56109
@@ -234,6 +287,9 @@ def from_bytes_store(
234287
namespace: str = "",
235288
batch_size: Optional[int] = None,
236289
query_embedding_cache: Union[bool, ByteStore] = False,
290+
key_encoder: Union[
291+
Callable[[str], str], Literal["sha1", "blake2b", "sha256", "sha512"]
292+
] = "sha1",
237293
) -> CacheBackedEmbeddings:
238294
"""On-ramp that adds the necessary serialization and encoding to the store.
239295
@@ -248,9 +304,39 @@ def from_bytes_store(
248304
query_embedding_cache: The cache to use for storing query embeddings.
249305
True to use the same cache as document embeddings.
250306
False to not cache query embeddings.
307+
key_encoder: Optional callable to encode keys. If not provided,
308+
a default encoder using SHA‑1 will be used. SHA-1 is not
309+
collision-resistant, and a motivated attacker could craft two
310+
different texts that hash to the same cache key.
311+
312+
New applications should use one of the alternative encoders
313+
or provide a custom and strong key encoder function to avoid this risk.
314+
315+
If you change a key encoder in an existing cache, consider
316+
just creating a new cache, to avoid (the potential for)
317+
collisions with existing keys or having duplicate keys
318+
for the same text in the cache.
319+
320+
Returns:
321+
An instance of CacheBackedEmbeddings that uses the provided cache.
251322
"""
252-
namespace = namespace
253-
key_encoder = _create_key_encoder(namespace)
323+
if isinstance(key_encoder, str):
324+
key_encoder = _make_default_key_encoder(namespace, key_encoder)
325+
elif callable(key_encoder):
326+
# If a custom key encoder is provided, it should not be used with a
327+
# namespace.
328+
# A user can handle namespacing in directly their custom key encoder.
329+
if namespace:
330+
raise ValueError(
331+
"Do not supply `namespace` when using a custom key_encoder; "
332+
"add any prefixing inside the encoder itself."
333+
)
334+
else:
335+
raise ValueError(
336+
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
337+
"or a callable that encodes keys."
338+
)
339+
254340
document_embedding_store = EncoderBackedStore[str, list[float]](
255341
document_embedding_cache,
256342
key_encoder,

libs/langchain/tests/unit_tests/embeddings/test_caching.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""Embeddings tests."""
22

3+
import hashlib
4+
import importlib
5+
import warnings
6+
37
import pytest
48
from langchain_core.embeddings import Embeddings
59

@@ -146,3 +150,89 @@ async def test_aembed_query_cached(
146150
keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr]
147151
assert len(keys) == 1
148152
assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
153+
154+
155+
def test_blake2b_encoder() -> None:
156+
"""Test that the blake2b encoder is used to encode keys in the cache store."""
157+
store = InMemoryStore()
158+
emb = MockEmbeddings()
159+
cbe = CacheBackedEmbeddings.from_bytes_store(
160+
emb, store, namespace="ns_", key_encoder="blake2b"
161+
)
162+
163+
text = "blake"
164+
cbe.embed_documents([text])
165+
166+
# rebuild the key exactly as the library does
167+
expected_key = "ns_" + hashlib.blake2b(text.encode()).hexdigest()
168+
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
169+
170+
171+
def test_sha256_encoder() -> None:
172+
"""Test that the sha256 encoder is used to encode keys in the cache store."""
173+
store = InMemoryStore()
174+
emb = MockEmbeddings()
175+
cbe = CacheBackedEmbeddings.from_bytes_store(
176+
emb, store, namespace="ns_", key_encoder="sha256"
177+
)
178+
179+
text = "foo"
180+
cbe.embed_documents([text])
181+
182+
# rebuild the key exactly as the library does
183+
expected_key = "ns_" + hashlib.sha256(text.encode()).hexdigest()
184+
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
185+
186+
187+
def test_sha512_encoder() -> None:
188+
"""Test that the sha512 encoder is used to encode keys in the cache store."""
189+
store = InMemoryStore()
190+
emb = MockEmbeddings()
191+
cbe = CacheBackedEmbeddings.from_bytes_store(
192+
emb, store, namespace="ns_", key_encoder="sha512"
193+
)
194+
195+
text = "foo"
196+
cbe.embed_documents([text])
197+
198+
# rebuild the key exactly as the library does
199+
expected_key = "ns_" + hashlib.sha512(text.encode()).hexdigest()
200+
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
201+
202+
203+
def test_sha1_warning_emitted_once() -> None:
204+
"""Test that a warning is emitted when using SHA‑1 as the default key encoder."""
205+
module = importlib.import_module(CacheBackedEmbeddings.__module__)
206+
207+
# Create a *temporary* MonkeyPatch object whose effects disappear
208+
# automatically when the with‑block exits.
209+
with pytest.MonkeyPatch.context() as mp:
210+
# We're monkey patching the module to reset the `_warned_about_sha1` flag
211+
# which may have been set while testing other parts of the codebase.
212+
mp.setattr(module, "_warned_about_sha1", False, raising=False)
213+
214+
store = InMemoryStore()
215+
emb = MockEmbeddings()
216+
217+
with warnings.catch_warnings(record=True) as caught:
218+
warnings.simplefilter("always")
219+
CacheBackedEmbeddings.from_bytes_store(emb, store) # triggers warning
220+
CacheBackedEmbeddings.from_bytes_store(emb, store) # silent
221+
222+
sha1_msgs = [w for w in caught if "SHA‑1" in str(w.message)]
223+
assert len(sha1_msgs) == 1
224+
225+
226+
def test_custom_encoder() -> None:
227+
"""Test that a custom encoder can be used to encode keys in the cache store."""
228+
store = InMemoryStore()
229+
emb = MockEmbeddings()
230+
231+
def custom_upper(text: str) -> str: # very simple demo encoder
232+
return "CUSTOM_" + text.upper()
233+
234+
cbe = CacheBackedEmbeddings.from_bytes_store(emb, store, key_encoder=custom_upper)
235+
txt = "x"
236+
cbe.embed_documents([txt])
237+
238+
assert list(cbe.document_embedding_store.yield_keys()) == ["CUSTOM_X"]

0 commit comments

Comments
 (0)