1818import uuid
1919
2020from newrelic .api .function_trace import FunctionTrace
21- from newrelic .api .time_trace import get_trace_linking_metadata
21+ from newrelic .api .time_trace import current_trace , get_trace_linking_metadata
2222from newrelic .api .transaction import current_transaction
2323from newrelic .common .object_wrapper import wrap_function_wrapper
2424from newrelic .common .package_version_utils import get_package_version
2525from newrelic .common .signature import bind_args
2626from newrelic .core .config import global_settings
27+ from newrelic .core .context import context_wrapper
2728
2829_logger = logging .getLogger (__name__ )
2930LANGCHAIN_VERSION = get_package_version ("langchain" )
124125}
125126
126127
127- def _create_error_vectorstore_events (transaction , search_id , args , kwargs , linking_metadata ):
128+ def bind_submit (func , * args , ** kwargs ):
129+ return {"func" : func , "args" : args , "kwargs" : kwargs }
130+
131+
132+ def wrap_ContextThreadPoolExecutor_submit (wrapped , instance , args , kwargs ):
133+ trace = current_trace ()
134+ if not trace :
135+ return wrapped (* args , ** kwargs )
136+
137+ # Use hardened function signature bind so we have safety net catchall of args and kwargs.
138+ bound_args = bind_submit (* args , ** kwargs )
139+ bound_args ["func" ] = context_wrapper (bound_args ["func" ], trace = trace , strict = True )
140+ return wrapped (bound_args ["func" ], * bound_args ["args" ], ** bound_args ["kwargs" ])
141+
142+
143+ def _create_error_vectorstore_events (transaction , search_id , args , kwargs , linking_metadata , wrapped ):
128144 settings = transaction .settings if transaction .settings is not None else global_settings ()
129145 span_id = linking_metadata .get ("span.id" )
130146 trace_id = linking_metadata .get ("trace.id" )
131- request_query , request_k = bind_similarity_search (* args , ** kwargs )
147+ bound_args = bind_args (wrapped , args , kwargs )
148+ request_query = bound_args ["query" ]
149+ request_k = bound_args ["k" ]
132150 llm_metadata_dict = _get_llm_metadata (transaction )
133151 vectorstore_error_dict = {
134152 "request.k" : request_k ,
@@ -169,21 +187,17 @@ async def wrap_asimilarity_search(wrapped, instance, args, kwargs):
169187 except Exception as exc :
170188 ft .notice_error (attributes = {"vector_store_id" : search_id })
171189 ft .__exit__ (* sys .exc_info ())
172- _create_error_vectorstore_events (transaction , search_id , args , kwargs , linking_metadata )
190+ _create_error_vectorstore_events (transaction , search_id , args , kwargs , linking_metadata , wrapped )
173191 raise
174192 ft .__exit__ (None , None , None )
175193
176194 if not response :
177195 return response
178196
179- _record_vector_search_success (transaction , linking_metadata , ft , search_id , args , kwargs , response )
197+ _record_vector_search_success (transaction , linking_metadata , ft , search_id , args , kwargs , response , wrapped )
180198 return response
181199
182200
183- def bind_similarity_search (query , k , * args , ** kwargs ):
184- return query , k
185-
186-
187201def wrap_similarity_search (wrapped , instance , args , kwargs ):
188202 transaction = current_transaction ()
189203 if not transaction :
@@ -206,20 +220,22 @@ def wrap_similarity_search(wrapped, instance, args, kwargs):
206220 except Exception as exc :
207221 ft .notice_error (attributes = {"vector_store_id" : search_id })
208222 ft .__exit__ (* sys .exc_info ())
209- _create_error_vectorstore_events (transaction , search_id , args , kwargs , linking_metadata )
223+ _create_error_vectorstore_events (transaction , search_id , args , kwargs , linking_metadata , wrapped )
210224 raise
211225 ft .__exit__ (None , None , None )
212226
213227 if not response :
214228 return response
215229
216- _record_vector_search_success (transaction , linking_metadata , ft , search_id , args , kwargs , response )
230+ _record_vector_search_success (transaction , linking_metadata , ft , search_id , args , kwargs , response , wrapped )
217231 return response
218232
219233
220- def _record_vector_search_success (transaction , linking_metadata , ft , search_id , args , kwargs , response ):
234+ def _record_vector_search_success (transaction , linking_metadata , ft , search_id , args , kwargs , response , wrapped ):
221235 settings = transaction .settings if transaction .settings is not None else global_settings ()
222- request_query , request_k = bind_similarity_search (* args , ** kwargs )
236+ bound_args = bind_args (wrapped , args , kwargs )
237+ request_query = bound_args ["query" ]
238+ request_k = bound_args ["k" ]
223239 duration = ft .duration * 1000
224240 response_number_of_documents = len (response )
225241 llm_metadata_dict = _get_llm_metadata (transaction )
@@ -879,3 +895,8 @@ def instrument_langchain_callbacks_manager(module):
879895 wrap_function_wrapper (module , "CallbackManager.on_chain_start" , wrap_on_chain_start )
880896 if hasattr (getattr (module , "AsyncCallbackManager" ), "on_chain_start" ):
881897 wrap_function_wrapper (module , "AsyncCallbackManager.on_chain_start" , wrap_async_on_chain_start )
898+
899+
900+ def instrument_langchain_core_runnables_config (module ):
901+ if hasattr (module , "ContextThreadPoolExecutor" ):
902+ wrap_function_wrapper (module , "ContextThreadPoolExecutor.submit" , wrap_ContextThreadPoolExecutor_submit )
0 commit comments