Skip to content

Commit a0a9dd7

Browse files
committed
Fix bugs
1 parent ec59999 commit a0a9dd7

File tree

5 files changed

+414
-23
lines changed

5 files changed

+414
-23
lines changed

newrelic/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,6 +2086,12 @@ def _process_module_builtin_defaults():
20862086
"newrelic.hooks.mlmodel_openai",
20872087
"instrument_openai_resources_chat_completions",
20882088
)
2089+
2090+
_process_module_definition(
2091+
"openai.resources.completions",
2092+
"newrelic.hooks.mlmodel_openai",
2093+
"instrument_openai_resources_chat_completions",
2094+
)
20892095
_process_module_definition(
20902096
"openai._base_client",
20912097
"newrelic.hooks.mlmodel_openai",
@@ -2103,6 +2109,11 @@ def _process_module_builtin_defaults():
21032109
"newrelic.hooks.mlmodel_langchain",
21042110
"instrument_langchain_runnables_chains_base",
21052111
)
2112+
_process_module_definition(
2113+
"langchain_core.runnables.config",
2114+
"newrelic.hooks.mlmodel_langchain",
2115+
"instrument_langchain_core_runnables_config",
2116+
)
21062117
_process_module_definition(
21072118
"langchain.chains.base",
21082119
"newrelic.hooks.mlmodel_langchain",

newrelic/hooks/mlmodel_langchain.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
import uuid
1919

2020
from newrelic.api.function_trace import FunctionTrace
21-
from newrelic.api.time_trace import get_trace_linking_metadata
21+
from newrelic.api.time_trace import current_trace, get_trace_linking_metadata
2222
from newrelic.api.transaction import current_transaction
2323
from newrelic.common.object_wrapper import wrap_function_wrapper
2424
from newrelic.common.package_version_utils import get_package_version
2525
from newrelic.common.signature import bind_args
2626
from newrelic.core.config import global_settings
27+
from newrelic.core.context import context_wrapper
2728

2829
_logger = logging.getLogger(__name__)
2930
LANGCHAIN_VERSION = get_package_version("langchain")
@@ -124,11 +125,28 @@
124125
}
125126

126127

127-
def _create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata):
128+
def bind_submit(func, *args, **kwargs):
129+
return {"func": func, "args": args, "kwargs": kwargs}
130+
131+
132+
def wrap_ContextThreadPoolExecutor_submit(wrapped, instance, args, kwargs):
133+
trace = current_trace()
134+
if not trace:
135+
return wrapped(*args, **kwargs)
136+
137+
# Use hardened function signature bind so we have safety net catchall of args and kwargs.
138+
bound_args = bind_submit(*args, **kwargs)
139+
bound_args["func"] = context_wrapper(bound_args["func"], trace=trace, strict=True)
140+
return wrapped(bound_args["func"], *bound_args["args"], **bound_args["kwargs"])
141+
142+
143+
def _create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata, wrapped):
128144
settings = transaction.settings if transaction.settings is not None else global_settings()
129145
span_id = linking_metadata.get("span.id")
130146
trace_id = linking_metadata.get("trace.id")
131-
request_query, request_k = bind_similarity_search(*args, **kwargs)
147+
bound_args = bind_args(wrapped, args, kwargs)
148+
request_query = bound_args["query"]
149+
request_k = bound_args["k"]
132150
llm_metadata_dict = _get_llm_metadata(transaction)
133151
vectorstore_error_dict = {
134152
"request.k": request_k,
@@ -169,21 +187,17 @@ async def wrap_asimilarity_search(wrapped, instance, args, kwargs):
169187
except Exception as exc:
170188
ft.notice_error(attributes={"vector_store_id": search_id})
171189
ft.__exit__(*sys.exc_info())
172-
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata)
190+
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata, wrapped)
173191
raise
174192
ft.__exit__(None, None, None)
175193

176194
if not response:
177195
return response
178196

179-
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response)
197+
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response, wrapped)
180198
return response
181199

182200

183-
def bind_similarity_search(query, k, *args, **kwargs):
184-
return query, k
185-
186-
187201
def wrap_similarity_search(wrapped, instance, args, kwargs):
188202
transaction = current_transaction()
189203
if not transaction:
@@ -206,20 +220,22 @@ def wrap_similarity_search(wrapped, instance, args, kwargs):
206220
except Exception as exc:
207221
ft.notice_error(attributes={"vector_store_id": search_id})
208222
ft.__exit__(*sys.exc_info())
209-
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata)
223+
_create_error_vectorstore_events(transaction, search_id, args, kwargs, linking_metadata, wrapped)
210224
raise
211225
ft.__exit__(None, None, None)
212226

213227
if not response:
214228
return response
215229

216-
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response)
230+
_record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response, wrapped)
217231
return response
218232

219233

220-
def _record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response):
234+
def _record_vector_search_success(transaction, linking_metadata, ft, search_id, args, kwargs, response, wrapped):
221235
settings = transaction.settings if transaction.settings is not None else global_settings()
222-
request_query, request_k = bind_similarity_search(*args, **kwargs)
236+
bound_args = bind_args(wrapped, args, kwargs)
237+
request_query = bound_args["query"]
238+
request_k = bound_args["k"]
223239
duration = ft.duration * 1000
224240
response_number_of_documents = len(response)
225241
llm_metadata_dict = _get_llm_metadata(transaction)
@@ -879,3 +895,8 @@ def instrument_langchain_callbacks_manager(module):
879895
wrap_function_wrapper(module, "CallbackManager.on_chain_start", wrap_on_chain_start)
880896
if hasattr(getattr(module, "AsyncCallbackManager"), "on_chain_start"):
881897
wrap_function_wrapper(module, "AsyncCallbackManager.on_chain_start", wrap_async_on_chain_start)
898+
899+
900+
def instrument_langchain_core_runnables_config(module):
901+
if hasattr(module, "ContextThreadPoolExecutor"):
902+
wrap_function_wrapper(module, "ContextThreadPoolExecutor.submit", wrap_ContextThreadPoolExecutor_submit)

newrelic/hooks/mlmodel_openai.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,9 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
492492
finish_reason = None
493493
choices = response.get("choices") or []
494494
if choices:
495-
output_message_list = [choices[0].get("message")]
495+
output_message_list = [
496+
choices[0].get("message") or {"content": choices[0].get("text"), "role": "assistant"}
497+
]
496498
finish_reason = choices[0].get("finish_reason")
497499
else:
498500
response_model = kwargs.get("response.model")
@@ -507,7 +509,7 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
507509

508510
request_id = response_headers.get("x-request-id")
509511
organization = response_headers.get("openai-organization") or getattr(response, "organization", None)
510-
messages = kwargs.get("messages", None) or []
512+
messages = kwargs.get("messages") or [{"content": kwargs.get("prompt"), "role": "user"}] or []
511513
input_message_list = list(messages)
512514
full_chat_completion_summary_dict = {
513515
"id": completion_id,

0 commit comments

Comments
 (0)