
Commit c70ef91

fix(langchain): add patch test cases, fix double patch issue for embeddings/vectorstores [backport 1.17] (#6512)
Backport 9ec214f from #6475 to 1.17.

This PR adds patch test cases for the langchain integration, and fixes a bug where we were double patching `langchain.embeddings.HuggingFaceEmbeddings.embed_query/documents` and `langchain.vectorstores.Milvus.similarity_search`. This issue stemmed from the two classes being reused/inherited by other classes (`SentenceTransformerEmbeddings` and `Zilliz`, respectively), meaning that when we wrapped the latter two classes' methods, we unintentionally wrapped the former two classes' methods twice.

The fix involves checking that the function to wrap isn't already a wrapped method (specifically a `wrapt.ObjectProxy`, which we use in our integrations). This could potentially be problematic if a user wraps these methods before running our patching code, as we would then skip them, but at the time of writing I don't see a huge risk nor a cleaner solution to avoid double patching. (A minimal illustrative sketch of the double-patch scenario and the guard follows the checklists below.)

## Checklist

- [x] Change(s) are motivated and described in the PR description.
- [x] Testing strategy is described if automated tests are not included in the PR.
- [x] Risk is outlined (performance impact, potential for breakage, maintainability, etc).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed. If no release note is required, add label `changelog/no-changelog`.
- [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)).
- [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)).

## Reviewer Checklist

- [x] Title is accurate.
- [x] No unnecessary changes are introduced.
- [x] Description motivates each change.
- [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes unless absolutely necessary.
- [x] Testing strategy adequately addresses listed risk(s).
- [x] Change is maintainable (easy to change, telemetry, documentation).
- [x] Release note makes sense to a user of the library.
- [x] Reviewer has explicitly acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment.
- [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting).

Co-authored-by: Yun Kim <[email protected]>
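For illustration only (not part of the commit): a minimal sketch of why a class exposed under two names gets wrapped twice, and how checking for an existing `wrapt.ObjectProxy` before wrapping avoids it. The class names and wrapper below are toy stand-ins, and the sketch uses the standalone `wrapt` package rather than `ddtrace.vendor.wrapt`.

```python
import wrapt


class HuggingFaceEmbeddings:
    def embed_query(self, text):
        return [0.0]


# langchain re-exposes the same class under a second name (per the description
# above), so both entries in a patch list resolve to the same underlying method.
SentenceTransformerEmbeddings = HuggingFaceEmbeddings


def traced(wrapped, instance, args, kwargs):
    # Toy tracing wrapper standing in for ddtrace's traced_embedding().
    print("traced call")
    return wrapped(*args, **kwargs)


for cls in (HuggingFaceEmbeddings, SentenceTransformerEmbeddings):
    # The guard added in this commit: skip attributes that are already wrapped.
    if not isinstance(cls.embed_query, wrapt.ObjectProxy):
        wrapt.wrap_function_wrapper(cls, "embed_query", traced)

HuggingFaceEmbeddings().embed_query("hi")  # prints "traced call" once, not twice
```

Without the `isinstance` guard, the second loop iteration wraps the already-wrapped method again, so every call is traced twice.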
1 parent 09b4699 commit c70ef91

File tree

3 files changed: +103 / -11 lines


ddtrace/contrib/langchain/patch.py

Lines changed: 34 additions & 11 deletions

@@ -30,6 +30,7 @@
 from ddtrace.internal.utils.formats import asbool
 from ddtrace.internal.utils.formats import deep_getattr
 from ddtrace.pin import Pin
+from ddtrace.vendor import wrapt


 if TYPE_CHECKING:

@@ -750,9 +751,8 @@ def patch():
     )
     integration.start_log_writer()

-    # TODO: check if we need to version gate LLM/Chat/TextEmbedding
     wrap("langchain", "llms.base.BaseLLM.generate", traced_llm_generate(langchain))
-    wrap("langchain", "llms.BaseLLM.agenerate", traced_llm_agenerate(langchain))
+    wrap("langchain", "llms.base.BaseLLM.agenerate", traced_llm_agenerate(langchain))
     wrap("langchain", "chat_models.base.BaseChatModel.generate", traced_chat_model_generate(langchain))
     wrap("langchain", "chat_models.base.BaseChatModel.agenerate", traced_chat_model_agenerate(langchain))
     wrap("langchain", "chains.base.Chain.__call__", traced_chain_call(langchain))

@@ -761,18 +761,32 @@ def patch():
     # wrap each langchain-provided text embedding model.
     for text_embedding_model in text_embedding_models:
         if hasattr(langchain.embeddings, text_embedding_model):
-            wrap("langchain", "embeddings.%s.embed_query" % text_embedding_model, traced_embedding(langchain))
-            wrap("langchain", "embeddings.%s.embed_documents" % text_embedding_model, traced_embedding(langchain))
-            # TODO: langchain >= 0.0.209 includes async embedding implementation (only for OpenAI)
+            # Ensure not double patched, as some Embeddings interfaces are pointers to other Embeddings.
+            if not isinstance(
+                deep_getattr(langchain.embeddings, "%s.embed_query" % text_embedding_model), wrapt.ObjectProxy
+            ):
+                wrap("langchain", "embeddings.%s.embed_query" % text_embedding_model, traced_embedding(langchain))
+            if not isinstance(
+                deep_getattr(langchain.embeddings, "%s.embed_documents" % text_embedding_model), wrapt.ObjectProxy
+            ):
+                wrap("langchain", "embeddings.%s.embed_documents" % text_embedding_model, traced_embedding(langchain))
+            # TODO: langchain >= 0.0.209 includes async embedding implementation (only for OpenAI)
     # We need to do the same with Vectorstores.
     for vectorstore in vectorstores:
         if hasattr(langchain.vectorstores, vectorstore):
-            wrap("langchain", "vectorstores.%s.similarity_search" % vectorstore, traced_similarity_search(langchain))
+            # Ensure not double patched, as some Embeddings interfaces are pointers to other Embeddings.
+            if not isinstance(
+                deep_getattr(langchain.vectorstores, "%s.similarity_search" % vectorstore), wrapt.ObjectProxy
+            ):
+                wrap(
+                    "langchain", "vectorstores.%s.similarity_search" % vectorstore, traced_similarity_search(langchain)
+                )


 def unpatch():
-    if getattr(langchain, "_datadog_patch", False):
-        setattr(langchain, "_datadog_patch", False)
+    if not getattr(langchain, "_datadog_patch", False):
+        return
+    setattr(langchain, "_datadog_patch", False)

     unwrap(langchain.llms.base.BaseLLM, "generate")
     unwrap(langchain.llms.base.BaseLLM, "agenerate")

@@ -782,10 +796,19 @@ def unpatch():
     unwrap(langchain.chains.base.Chain, "acall")
     for text_embedding_model in text_embedding_models:
         if hasattr(langchain.embeddings, text_embedding_model):
-            unwrap(getattr(langchain.embeddings, text_embedding_model), "embed_query")
-            unwrap(getattr(langchain.embeddings, text_embedding_model), "embed_documents")
+            if isinstance(
+                deep_getattr(langchain.embeddings, "%s.embed_query" % text_embedding_model), wrapt.ObjectProxy
+            ):
+                unwrap(getattr(langchain.embeddings, text_embedding_model), "embed_query")
+            if isinstance(
+                deep_getattr(langchain.embeddings, "%s.embed_documents" % text_embedding_model), wrapt.ObjectProxy
+            ):
+                unwrap(getattr(langchain.embeddings, text_embedding_model), "embed_documents")
     for vectorstore in vectorstores:
         if hasattr(langchain.vectorstores, vectorstore):
-            unwrap(getattr(langchain.vectorstores, vectorstore), "similarity_search")
+            if isinstance(
+                deep_getattr(langchain.vectorstores, "%s.similarity_search" % vectorstore), wrapt.ObjectProxy
+            ):
+                unwrap(getattr(langchain.vectorstores, vectorstore), "similarity_search")

     delattr(langchain, "_datadog_integration")
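As an aside, `deep_getattr` (imported above from `ddtrace.internal.utils.formats`) resolves a dotted attribute path on a module without raising. A simplified stand-in with roughly that behavior, shown here only for readers unfamiliar with the helper (an illustrative assumption, not the vendored implementation):

```python
def deep_getattr_sketch(obj, attr_string, default=None):
    # Walk an "A.b.c" style path with getattr, returning `default` if any hop
    # is missing instead of raising AttributeError.
    for attr in attr_string.split("."):
        obj = getattr(obj, attr, None)
        if obj is None:
            return default
    return obj


# Used as in the patch code above, the resolved attribute is then checked with
# isinstance(..., wrapt.ObjectProxy) before wrapping or unwrapping, e.g.:
#   deep_getattr_sketch(langchain.embeddings, "OpenAIEmbeddings.embed_query")
```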
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+---
+fixes:
+  - |
+    langchain: This fix resolves an issue where ``langchain.embeddings.HuggingFaceEmbeddings`` embedding
+    methods, and ``langchain.vectorstores.Milvus.similarity_search`` were patched twice
+    due to a nested class hierarchy in ``langchain``.
Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+from ddtrace.contrib.langchain import patch
+from ddtrace.contrib.langchain import unpatch
+from ddtrace.contrib.langchain.constants import text_embedding_models
+from ddtrace.contrib.langchain.constants import vectorstores
+from tests.contrib.patch import PatchTestCase
+
+
+class TestLangchainPatch(PatchTestCase.Base):
+    __integration_name__ = "langchain"
+    __module_name__ = "langchain"
+    __patch_func__ = patch
+    __unpatch_func__ = unpatch
+
+    def assert_module_patched(self, langchain):
+        self.assert_wrapped(langchain.llms.base.BaseLLM.generate)
+        self.assert_wrapped(langchain.llms.base.BaseLLM.agenerate)
+        self.assert_wrapped(langchain.chat_models.base.BaseChatModel.generate)
+        self.assert_wrapped(langchain.chat_models.base.BaseChatModel.agenerate)
+        self.assert_wrapped(langchain.chains.base.Chain.__call__)
+        self.assert_wrapped(langchain.chains.base.Chain.acall)
+        for text_embedding_model in text_embedding_models:
+            embedding_model = getattr(langchain.embeddings, text_embedding_model, None)
+            if embedding_model:
+                self.assert_wrapped(embedding_model.embed_query)
+                self.assert_wrapped(embedding_model.embed_documents)
+        for vectorstore in vectorstores:
+            vectorstore_interface = getattr(langchain.vectorstores, vectorstore, None)
+            if vectorstore_interface:
+                self.assert_wrapped(vectorstore_interface.similarity_search)
+
+    def assert_not_module_patched(self, langchain):
+        self.assert_not_wrapped(langchain.llms.base.BaseLLM.generate)
+        self.assert_not_wrapped(langchain.llms.base.BaseLLM.agenerate)
+        self.assert_not_wrapped(langchain.chat_models.base.BaseChatModel.generate)
+        self.assert_not_wrapped(langchain.chat_models.base.BaseChatModel.agenerate)
+        self.assert_not_wrapped(langchain.chains.base.Chain.__call__)
+        self.assert_not_wrapped(langchain.chains.base.Chain.acall)
+        for text_embedding_model in text_embedding_models:
+            embedding_model = getattr(langchain.embeddings, text_embedding_model, None)
+            if embedding_model:
+                self.assert_not_wrapped(embedding_model.embed_query)
+                self.assert_not_wrapped(embedding_model.embed_documents)
+        for vectorstore in vectorstores:
+            vectorstore_interface = getattr(langchain.vectorstores, vectorstore, None)
+            if vectorstore_interface:
+                self.assert_not_wrapped(vectorstore_interface.similarity_search)
+
+    def assert_not_module_double_patched(self, langchain):
+        self.assert_not_double_wrapped(langchain.llms.base.BaseLLM.generate)
+        self.assert_not_double_wrapped(langchain.llms.base.BaseLLM.agenerate)
+        self.assert_not_double_wrapped(langchain.chat_models.base.BaseChatModel.generate)
+        self.assert_not_double_wrapped(langchain.chat_models.base.BaseChatModel.agenerate)
+        self.assert_not_double_wrapped(langchain.chains.base.Chain.__call__)
+        self.assert_not_double_wrapped(langchain.chains.base.Chain.acall)
+        for text_embedding_model in text_embedding_models:
+            embedding_model = getattr(langchain.embeddings, text_embedding_model, None)
+            if embedding_model:
+                self.assert_not_double_wrapped(embedding_model.embed_query)
+                self.assert_not_double_wrapped(embedding_model.embed_documents)
+        for vectorstore in vectorstores:
+            vectorstore_interface = getattr(langchain.vectorstores, vectorstore, None)
+            if vectorstore_interface:
+                self.assert_not_double_wrapped(vectorstore_interface.similarity_search)
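The `assert_wrapped` and `assert_not_double_wrapped` helpers come from `tests.contrib.patch.PatchTestCase.Base`. A rough sketch of the kind of checks such helpers perform (an assumption for illustration, not the actual shared test implementation):

```python
from ddtrace.vendor import wrapt


def looks_wrapped(obj):
    # A patched method shows up as a wrapt proxy exposing the original callable
    # via its __wrapped__ attribute.
    return isinstance(obj, wrapt.ObjectProxy)


def looks_double_wrapped(obj):
    # Double patching shows up as a proxy whose __wrapped__ is itself a proxy,
    # which is exactly what the ObjectProxy guard in patch() now prevents.
    return looks_wrapped(obj) and isinstance(obj.__wrapped__, wrapt.ObjectProxy)
```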
