Skip to content

Commit 41e5475

Browse files
authored
Merge pull request #1199 from newrelic/fix-ai-bug
Fixes for langchain & openai
2 parents ec59999 + f7719a3 commit 41e5475

File tree

7 files changed

+474
-39
lines changed

7 files changed

+474
-39
lines changed

newrelic/config.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,6 +2086,12 @@ def _process_module_builtin_defaults():
20862086
"newrelic.hooks.mlmodel_openai",
20872087
"instrument_openai_resources_chat_completions",
20882088
)
2089+
2090+
_process_module_definition(
2091+
"openai.resources.completions",
2092+
"newrelic.hooks.mlmodel_openai",
2093+
"instrument_openai_resources_chat_completions",
2094+
)
20892095
_process_module_definition(
20902096
"openai._base_client",
20912097
"newrelic.hooks.mlmodel_openai",
@@ -2103,6 +2109,11 @@ def _process_module_builtin_defaults():
21032109
"newrelic.hooks.mlmodel_langchain",
21042110
"instrument_langchain_runnables_chains_base",
21052111
)
2112+
_process_module_definition(
2113+
"langchain_core.runnables.config",
2114+
"newrelic.hooks.mlmodel_langchain",
2115+
"instrument_langchain_core_runnables_config",
2116+
)
21062117
_process_module_definition(
21072118
"langchain.chains.base",
21082119
"newrelic.hooks.mlmodel_langchain",
@@ -2133,6 +2144,11 @@ def _process_module_builtin_defaults():
21332144
"newrelic.hooks.mlmodel_langchain",
21342145
"instrument_langchain_vectorstore_similarity_search",
21352146
)
2147+
_process_module_definition(
2148+
"langchain_community.vectorstores.aerospike",
2149+
"newrelic.hooks.mlmodel_langchain",
2150+
"instrument_langchain_vectorstore_similarity_search",
2151+
)
21362152
_process_module_definition(
21372153
"langchain_community.vectorstores.analyticdb",
21382154
"newrelic.hooks.mlmodel_langchain",
@@ -2148,6 +2164,11 @@ def _process_module_builtin_defaults():
21482164
"newrelic.hooks.mlmodel_langchain",
21492165
"instrument_langchain_vectorstore_similarity_search",
21502166
)
2167+
_process_module_definition(
2168+
"langchain_community.vectorstores.aperturedb",
2169+
"newrelic.hooks.mlmodel_langchain",
2170+
"instrument_langchain_vectorstore_similarity_search",
2171+
)
21512172
_process_module_definition(
21522173
"langchain_community.vectorstores.astradb",
21532174
"newrelic.hooks.mlmodel_langchain",
@@ -2163,6 +2184,11 @@ def _process_module_builtin_defaults():
21632184
"newrelic.hooks.mlmodel_langchain",
21642185
"instrument_langchain_vectorstore_similarity_search",
21652186
)
2187+
_process_module_definition(
2188+
"langchain_community.vectorstores.azure_cosmos_db_no_sql",
2189+
"newrelic.hooks.mlmodel_langchain",
2190+
"instrument_langchain_vectorstore_similarity_search",
2191+
)
21662192
_process_module_definition(
21672193
"langchain_community.vectorstores.azure_cosmos_db",
21682194
"newrelic.hooks.mlmodel_langchain",
@@ -2349,6 +2375,12 @@ def _process_module_builtin_defaults():
23492375
"instrument_langchain_vectorstore_similarity_search",
23502376
)
23512377

2378+
_process_module_definition(
2379+
"langchain_community.vectorstores.manticore_search",
2380+
"newrelic.hooks.mlmodel_langchain",
2381+
"instrument_langchain_vectorstore_similarity_search",
2382+
)
2383+
23522384
_process_module_definition(
23532385
"langchain_community.vectorstores.marqo",
23542386
"newrelic.hooks.mlmodel_langchain",
@@ -2397,6 +2429,12 @@ def _process_module_builtin_defaults():
23972429
"instrument_langchain_vectorstore_similarity_search",
23982430
)
23992431

2432+
_process_module_definition(
2433+
"langchain_community.vectorstores.thirdai_neuraldb",
2434+
"newrelic.hooks.mlmodel_langchain",
2435+
"instrument_langchain_vectorstore_similarity_search",
2436+
)
2437+
24002438
_process_module_definition(
24012439
"langchain_community.vectorstores.nucliadb",
24022440
"newrelic.hooks.mlmodel_langchain",
@@ -2625,6 +2663,12 @@ def _process_module_builtin_defaults():
26252663
"instrument_langchain_vectorstore_similarity_search",
26262664
)
26272665

2666+
_process_module_definition(
2667+
"langchain_community.vectorstores.zep_cloud",
2668+
"newrelic.hooks.mlmodel_langchain",
2669+
"instrument_langchain_vectorstore_similarity_search",
2670+
)
2671+
26282672
_process_module_definition(
26292673
"langchain_community.vectorstores.zep",
26302674
"newrelic.hooks.mlmodel_langchain",

newrelic/hooks/mlmodel_langchain.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,29 @@
1818
import uuid
1919

2020
from newrelic.api.function_trace import FunctionTrace
21-
from newrelic.api.time_trace import get_trace_linking_metadata
21+
from newrelic.api.time_trace import current_trace, get_trace_linking_metadata
2222
from newrelic.api.transaction import current_transaction
2323
from newrelic.common.object_wrapper import wrap_function_wrapper
2424
from newrelic.common.package_version_utils import get_package_version
2525
from newrelic.common.signature import bind_args
2626
from newrelic.core.config import global_settings
27+
from newrelic.core.context import context_wrapper
2728

2829
_logger = logging.getLogger(__name__)
2930
LANGCHAIN_VERSION = get_package_version("langchain")
3031
EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE = "Exception occurred in langchain instrumentation: While reporting an exception in langchain, another exception occurred. Report this issue to New Relic Support.\n%s"
3132
RECORD_EVENTS_FAILURE_LOG_MESSAGE = "Exception occurred in langchain instrumentation: Failed to record LLM events. Report this issue to New Relic Support.\n%s"
32-
3333
VECTORSTORE_CLASSES = {
34+
"langchain_community.vectorstores.aerospike": "Aerospike",
3435
"langchain_community.vectorstores.alibabacloud_opensearch": "AlibabaCloudOpenSearch",
3536
"langchain_community.vectorstores.analyticdb": "AnalyticDB",
3637
"langchain_community.vectorstores.annoy": "Annoy",
3738
"langchain_community.vectorstores.apache_doris": "ApacheDoris",
39+
"langchain_community.vectorstores.aperturedb": "ApertureDB",
3840
"langchain_community.vectorstores.astradb": "AstraDB",
3941
"langchain_community.vectorstores.atlas": "AtlasDB",
4042
"langchain_community.vectorstores.awadb": "AwaDB",
43+
"langchain_community.vectorstores.azure_cosmos_db_no_sql": "AzureCosmosDBNoSqlVectorSearch",
4144
"langchain_community.vectorstores.azure_cosmos_db": "AzureCosmosDBVectorSearch",
4245
"langchain_community.vectorstores.azuresearch": "AzureSearch",
4346
"langchain_community.vectorstores.baiduvectordb": "BaiduVectorDB",
@@ -71,6 +74,7 @@
7174
"langchain_community.vectorstores.lancedb": "LanceDB",
7275
"langchain_community.vectorstores.lantern": "Lantern",
7376
"langchain_community.vectorstores.llm_rails": "LLMRails",
77+
"langchain_community.vectorstores.manticore_search": "ManticoreSearch",
7478
"langchain_community.vectorstores.marqo": "Marqo",
7579
"langchain_community.vectorstores.matching_engine": "MatchingEngine",
7680
"langchain_community.vectorstores.meilisearch": "Meilisearch",
@@ -79,7 +83,7 @@
7983
"langchain_community.vectorstores.mongodb_atlas": "MongoDBAtlasVectorSearch",
8084
"langchain_community.vectorstores.myscale": "MyScale",
8185
"langchain_community.vectorstores.neo4j_vector": "Neo4jVector",
82-
"langchain_community.vectorstores.thirdai_neuraldb": "NeuralDBVectorStore",
86+
"langchain_community.vectorstores.thirdai_neuraldb": ["NeuralDBClientVectorStore", "NeuralDBVectorStore"],
8387
"langchain_community.vectorstores.nucliadb": "NucliaDB",
8488
"langchain_community.vectorstores.oraclevs": "OracleVS",
8589
"langchain_community.vectorstores.opensearch_vector_search": "OpenSearchVectorSearch",
@@ -118,17 +122,35 @@
118122
"langchain_community.vectorstores.weaviate": "Weaviate",
119123
"langchain_community.vectorstores.xata": "XataVectorStore",
120124
"langchain_community.vectorstores.yellowbrick": "Yellowbrick",
125+
"langchain_community.vectorstores.zep_cloud": "ZepCloudVectorStore",
121126
"langchain_community.vectorstores.zep": "ZepVectorStore",
122127
"langchain_community.vectorstores.docarray.hnsw": "DocArrayHnswSearch",
123128
"langchain_community.vectorstores.docarray.in_memory": "DocArrayInMemorySearch",
124129
}
125130

126131

127-
def _create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata):
132+
def bind_submit(func, *args, **kwargs):
133+
return {"func": func, "args": args, "kwargs": kwargs}
134+
135+
136+
def wrap_ContextThreadPoolExecutor_submit(wrapped, instance, args, kwargs):
137+
trace = current_trace()
138+
if not trace:
139+
return wrapped(*args, **kwargs)
140+
141+
# Use hardened function signature bind so we have safety net catchall of args and kwargs.
142+
bound_args = bind_submit(*args, **kwargs)
143+
bound_args["func"] = context_wrapper(bound_args["func"], trace=trace, strict=True)
144+
return wrapped(bound_args["func"], *bound_args["args"], **bound_args["kwargs"])
145+
146+
147+
def _create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata, wrapped):
128148
settings = transaction.settings if transaction.settings is not None else global_settings()
129149
span_id = linking_metadata.get("span.id")
130150
trace_id = linking_metadata.get("trace.id")
131-
request_query, request_k = bind_similarity_search(*args, **kwargs)
151+
bound_args = bind_args(wrapped, args, kwargs)
152+
request_query = bound_args["query"]
153+
request_k = bound_args["k"]
132154
llm_metadata_dict = _get_llm_metadata(transaction)
133155
vectorstore_error_dict = {
134156
"request.k": request_k,
@@ -169,21 +191,17 @@ async def wrap_asimilarity_search(wrapped, instance, args, kwargs):
169191
except Exception as exc:
170192
ft.notice_error(attributes={"vector_store_id": search_id})
171193
ft.__exit__(*sys.exc_info())
172-
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata)
194+
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata, wrapped)
173195
raise
174196
ft.__exit__(None, None, None)
175197

176198
if not response:
177199
return response
178200

179-
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response)
201+
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response, wrapped)
180202
return response
181203

182204

183-
def bind_similarity_search(query, k, *args, **kwargs):
184-
return query, k
185-
186-
187205
def wrap_similarity_search(wrapped, instance, args, kwargs):
188206
transaction = current_transaction()
189207
if not transaction:
@@ -206,20 +224,22 @@ def wrap_similarity_search(wrapped, instance, args, kwargs):
206224
except Exception as exc:
207225
ft.notice_error(attributes={"vector_store_id": search_id})
208226
ft.__exit__(*sys.exc_info())
209-
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata)
227+
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata, wrapped)
210228
raise
211229
ft.__exit__(None, None, None)
212230

213231
if not response:
214232
return response
215233

216-
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response)
234+
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response, wrapped)
217235
return response
218236

219237

220-
def _record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response):
238+
def _record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response, wrapped):
221239
settings = transaction.settings if transaction.settings is not None else global_settings()
222-
request_query, request_k = bind_similarity_search(*args, **kwargs)
240+
bound_args = bind_args(wrapped, args, kwargs)
241+
request_query = bound_args["query"]
242+
request_k = bound_args["k"]
223243
duration = ft.duration * 1000
224244
response_number_of_documents = len(response)
225245
llm_metadata_dict = _get_llm_metadata(transaction)
@@ -855,12 +875,20 @@ def instrument_langchain_chains_base(module):
855875

856876

857877
def instrument_langchain_vectorstore_similarity_search(module):
858-
vector_class = VECTORSTORE_CLASSES.get(module.__name__)
859-
860-
if vector_class and hasattr(getattr(module, vector_class, ""), "similarity_search"):
861-
wrap_function_wrapper(module, "%s.similarity_search" % vector_class, wrap_similarity_search)
862-
if vector_class and hasattr(getattr(module, vector_class, ""), "asimilarity_search"):
863-
wrap_function_wrapper(module, "%s.asimilarity_search" % vector_class, wrap_asimilarity_search)
878+
def _instrument_class(module, vector_class):
879+
if hasattr(getattr(module, vector_class, ""), "similarity_search"):
880+
wrap_function_wrapper(module, "%s.similarity_search" % vector_class, wrap_similarity_search)
881+
if hasattr(getattr(module, vector_class, ""), "asimilarity_search"):
882+
wrap_function_wrapper(module, "%s.asimilarity_search" % vector_class, wrap_asimilarity_search)
883+
884+
vector_classes = VECTORSTORE_CLASSES.get(module.__name__)
885+
if vector_classes is None:
886+
return
887+
if isinstance(vector_classes, list):
888+
for vector_class in vector_classes:
889+
_instrument_class(module, vector_class)
890+
else:
891+
_instrument_class(module, vector_classes)
864892

865893

866894
def instrument_langchain_core_tools(module):
@@ -879,3 +907,8 @@ def instrument_langchain_callbacks_manager(module):
879907
wrap_function_wrapper(module, "CallbackManager.on_chain_start", wrap_on_chain_start)
880908
if hasattr(getattr(module, "AsyncCallbackManager"), "on_chain_start"):
881909
wrap_function_wrapper(module, "AsyncCallbackManager.on_chain_start", wrap_async_on_chain_start)
910+
911+
912+
def instrument_langchain_core_runnables_config(module):
913+
if hasattr(module, "ContextThreadPoolExecutor"):
914+
wrap_function_wrapper(module, "ContextThreadPoolExecutor.submit", wrap_ContextThreadPoolExecutor_submit)

newrelic/hooks/mlmodel_openai.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,9 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
492492
finish_reason = None
493493
choices = response.get("choices") or []
494494
if choices:
495-
output_message_list = [choices[0].get("message")]
495+
output_message_list = [
496+
choices[0].get("message") or {"content": choices[0].get("text"), "role": "assistant"}
497+
]
496498
finish_reason = choices[0].get("finish_reason")
497499
else:
498500
response_model = kwargs.get("response.model")
@@ -507,7 +509,7 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
507509

508510
request_id = response_headers.get("x-request-id")
509511
organization = response_headers.get("openai-organization") or getattr(response, "organization", None)
510-
messages = kwargs.get("messages", None) or []
512+
messages = kwargs.get("messages") or [{"content": kwargs.get("prompt"), "role": "user"}]
511513
input_message_list = list(messages)
512514
full_chat_completion_summary_dict = {
513515
"id": completion_id,

0 commit comments

Comments
 (0)