ddtrace/llmobs/_llmobs.py (155 changes: 90 additions & 65 deletions)
@@ -465,7 +465,7 @@ def _do_annotations(self, span: Span) -> None:
         with self._annotation_context_lock:
             for _, context_id, annotation_kwargs in self._instance._annotations:
                 if current_context_id == context_id:
-                    self.annotate(span, **annotation_kwargs)
+                    self.annotate(span, **annotation_kwargs, _suppress_span_kind_error=True)
 
     def _child_after_fork(self) -> None:
         self._llmobs_span_writer = self._llmobs_span_writer.recreate()
@@ -505,7 +505,7 @@ def _stop_service(self) -> None:
         core.reset_listeners("trace.span_start", self._on_span_start)
         core.reset_listeners("trace.span_finish", self._on_span_finish)
         core.reset_listeners("http.span_inject", self._inject_llmobs_context)
-        core.reset_listeners("http.activate_distributed_headers", self._activate_llmobs_distributed_context)
+        core.reset_listeners("http.activate_distributed_headers", self._activate_llmobs_distributed_context_soft_fail)
         core.reset_listeners("threading.submit", self._current_trace_context)
         core.reset_listeners("threading.execution", self._llmobs_context_provider.activate)
         core.reset_listeners("asyncio.create_task", self._on_asyncio_create_task)
@@ -620,7 +620,7 @@ def enable(
         core.on("trace.span_start", cls._instance._on_span_start)
         core.on("trace.span_finish", cls._instance._on_span_finish)
         core.on("http.span_inject", cls._inject_llmobs_context)
-        core.on("http.activate_distributed_headers", cls._activate_llmobs_distributed_context)
+        core.on("http.activate_distributed_headers", cls._activate_llmobs_distributed_context_soft_fail)
         core.on("threading.submit", cls._instance._current_trace_context, "llmobs_ctx")
         core.on("threading.execution", cls._instance._llmobs_context_provider.activate)
         core.on("asyncio.create_task", cls._instance._on_asyncio_create_task)
@@ -1014,16 +1014,14 @@ def export_span(cls, span: Optional[Span] = None) -> Optional[ExportedLLMObsSpan]:
         try:
             if span.span_type != SpanTypes.LLM:
                 error = "invalid_span"
-                log.warning("Span must be an LLMObs-generated span.")
-                return None
+                raise Exception("Span must be an LLMObs-generated span.")
             return ExportedLLMObsSpan(
                 span_id=str(span.span_id),
                 trace_id=format_trace_id(span._get_ctx_item(LLMOBS_TRACE_ID) or span.trace_id),
             )
         except (TypeError, AttributeError):
             error = "invalid_span"
-            log.warning("Failed to export span. Span must be a valid Span object.")
-            return None
+            raise Exception("Failed to export span. Span must be a valid Span object.") from None
         finally:
             telemetry.record_span_exported(span, error)
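
[Review note] With this hunk, export_span() raises a generic Exception instead of logging a warning and returning None. A minimal sketch of how a call site might adapt; the caller code below is hypothetical and not part of this diff:

    from ddtrace.llmobs import LLMObs

    try:
        exported = LLMObs.export_span()  # now raises if no valid LLM span is available
    except Exception:
        exported = None  # preserve the old "treat as absent" behavior at the call site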

@@ -1338,6 +1336,7 @@ def annotate(
         tags: Optional[Dict[str, Any]] = None,
         tool_definitions: Optional[List[Dict[str, Any]]] = None,
         _name: Optional[str] = None,
+        _suppress_span_kind_error: bool = False,
     ) -> None:
         """
         Sets metadata, inputs, outputs, tags, and metrics as provided for a given LLMObs span.
@@ -1397,32 +1396,29 @@ def annotate(
                 span = cls._instance._current_span()
                 if span is None:
                     error = "invalid_span_no_active_spans"
-                    log.warning("No span provided and no active LLMObs-generated span found.")
-                    return
+                    raise Exception("No span provided and no active LLMObs-generated span found.")
             if span.span_type != SpanTypes.LLM:
                 error = "invalid_span_type"
-                log.warning("Span must be an LLMObs-generated span.")
-                return
+                raise Exception("Span must be an LLMObs-generated span.")
             if span.finished:
                 error = "invalid_finished_span"
-                log.warning("Cannot annotate a finished span.")
-                return
+                raise Exception("Cannot annotate a finished span.")
             if metadata is not None:
                 if not isinstance(metadata, dict):
                     error = "invalid_metadata"
-                    log.warning("metadata must be a dictionary")
+                    raise Exception("metadata must be a dictionary")
                 else:
                     cls._set_dict_attribute(span, METADATA, metadata)
             if metrics is not None:
                 if not isinstance(metrics, dict) or not all(isinstance(v, (int, float)) for v in metrics.values()):
                     error = "invalid_metrics"
-                    log.warning("metrics must be a dictionary of string key - numeric value pairs.")
+                    raise Exception("metrics must be a dictionary of string key - numeric value pairs.")
                 else:
                     cls._set_dict_attribute(span, METRICS, metrics)
             if tags is not None:
                 if not isinstance(tags, dict):
                     error = "invalid_tags"
-                    log.warning("span tags must be a dictionary of string key - primitive value pairs.")
+                    raise Exception("span tags must be a dictionary of string key - primitive value pairs.")
                 else:
                     session_id = tags.get("session_id")
                     if session_id:
@@ -1441,10 +1437,11 @@ def annotate(
                     cls._set_dict_attribute(span, INPUT_PROMPT, validated_prompt)
                 except (ValueError, TypeError) as e:
                     error = "invalid_prompt"
-                    log.warning("Failed to validate prompt with error:", str(e), exc_info=True)
-            if not span_kind:
-                log.debug("Span kind not specified, skipping annotation for input/output data")
-                return
+                    raise Exception("Failed to validate prompt with error:", str(e))
+            if (
+                not span_kind and not _suppress_span_kind_error
+            ):  # TODO(sabrenner): we should figure out how to remove this check for annotation contexts
+                raise Exception("Span kind not specified, skipping annotation for input/output data")
             if input_data is not None or output_data is not None:
                 if span_kind == "llm":
                     error = cls._tag_llm_io(span, input_messages=input_data, output_messages=output_data)
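
[Review note] Together with the signature change above, annotate() now raises on an invalid, finished, or missing span and on malformed metadata/metrics/tags, while _do_annotations() passes _suppress_span_kind_error=True so annotation contexts keep their old lenient span-kind handling. A hypothetical call site under the new contract (the workflow name and field values below are illustrative only):

    from ddtrace.llmobs import LLMObs

    with LLMObs.workflow("process-request") as span:
        try:
            LLMObs.annotate(span, metadata={"step": "parse"}, metrics={"latency_ms": 12.5})
        except Exception:
            pass  # validation failures are now explicit; handle or let them propagate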
@@ -1471,7 +1468,9 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None) -> Optional[str]:
                 if input_messages.messages:
                     span._set_ctx_item(INPUT_MESSAGES, input_messages.messages)
             except TypeError:
-                log.warning("Failed to parse input messages.", exc_info=True)
+                log.warning(
+                    "Failed to parse input messages.", exc_info=True
+                )  # TODO: figure out how to raise this error and return the error type
                 return "invalid_io_messages"
         if output_messages is None:
             return None
@@ -1482,7 +1481,9 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None) -> Optional[str]:
                 return None
             span._set_ctx_item(OUTPUT_MESSAGES, output_messages.messages)
         except TypeError:
-            log.warning("Failed to parse output messages.", exc_info=True)
+            log.warning(
+                "Failed to parse output messages.", exc_info=True
+            )  # TODO: figure out how to raise this error and return the error type
             return "invalid_io_messages"
         return None
 
@@ -1498,7 +1499,9 @@ def _tag_embedding_io(cls, span, input_documents=None, output_text=None) -> Optional[str]:
                 if input_documents.documents:
                     span._set_ctx_item(INPUT_DOCUMENTS, input_documents.documents)
             except TypeError:
-                log.warning("Failed to parse input documents.", exc_info=True)
+                log.warning(
+                    "Failed to parse input documents.", exc_info=True
+                )  # TODO: figure out how to raise this error and return the error type
                 return "invalid_embedding_io"
         if output_text is None:
             return None
@@ -1521,7 +1524,9 @@ def _tag_retrieval_io(cls, span, input_text=None, output_documents=None) -> Optional[str]:
                 return None
             span._set_ctx_item(OUTPUT_DOCUMENTS, output_documents.documents)
         except TypeError:
-            log.warning("Failed to parse output documents.", exc_info=True)
+            log.warning(
+                "Failed to parse output documents.", exc_info=True
+            )  # TODO: figure out how to raise this error and return the error type
             return "invalid_retrieval_io"
         return None
 
@@ -1712,17 +1717,15 @@ def submit_evaluation(
             raise TypeError("value must be a boolean for a boolean metric.")
 
         if tags is not None and not isinstance(tags, dict):
-            log.warning("tags must be a dictionary of string key-value pairs.")
-            tags = {}
+            raise Exception("tags must be a dictionary of string key-value pairs.")
 
         ml_app = ml_app if ml_app else config._llmobs_ml_app
         if not ml_app:
             error = "missing_ml_app"
-            log.warning(
+            raise Exception(
                 "ML App name is required for sending evaluation metrics. Evaluation metric data will not be sent. "
                 "Ensure this configuration is set before running your application."
             )
-            return
 
         evaluation_tags = {
             "ddtrace.version": ddtrace.__version__,
@@ -1735,7 +1738,7 @@ def submit_evaluation(
                     evaluation_tags[ensure_text(k)] = ensure_text(v)
                 except TypeError:
                     error = "invalid_tags"
-                    log.warning("Failed to parse tags. Tags for evaluation metrics must be strings.")
+                    raise Exception("Failed to parse tags. Tags for evaluation metrics must be strings.")
 
         evaluation_metric: LLMObsEvaluationMetricEvent = {
             "join_on": join_on,
@@ -1750,20 +1753,20 @@ def submit_evaluation(
         if assessment:
             if not isinstance(assessment, str) or assessment not in ("pass", "fail"):
                 error = "invalid_assessment"
-                log.warning("Failed to parse assessment. assessment must be either 'pass' or 'fail'.")
+                raise Exception("Failed to parse assessment. assessment must be either 'pass' or 'fail'.")
             else:
                 evaluation_metric["assessment"] = assessment
         if reasoning:
             if not isinstance(reasoning, str):
                 error = "invalid_reasoning"
-                log.warning("Failed to parse reasoning. reasoning must be a string.")
+                raise Exception("Failed to parse reasoning. reasoning must be a string.")
             else:
                 evaluation_metric["reasoning"] = reasoning
 
         if metadata:
             if not isinstance(metadata, dict):
                 error = "invalid_metadata"
-                log.warning("metadata must be json serializable dictionary.")
+                raise Exception("metadata must be json serializable dictionary.")
             else:
                 metadata = safe_json(metadata)
                 if metadata and isinstance(metadata, str):
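
[Review note] submit_evaluation() likewise converts its validation warnings (tags, ML app, assessment, reasoning, metadata) into exceptions. A hypothetical call site under the new behavior; the label, join_on payload, and tag values below are illustrative only:

    from ddtrace.llmobs import LLMObs

    try:
        LLMObs.submit_evaluation(
            join_on={"span": {"span_id": "123", "trace_id": "456"}},
            label="helpfulness",
            metric_type="score",
            value=0.9,
            tags={"team": "ml-platform"},  # anything but a dict of strings now raises
        )
    except Exception:
        pass  # previously these validation failures were only logged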
@@ -1801,7 +1804,9 @@ def _inject_llmobs_context(cls, span_context: Context, request_headers: Dict[str, str]) -> None:
             span_context._meta[PROPAGATED_ML_APP_KEY] = ml_app
 
     @classmethod
-    def inject_distributed_headers(cls, request_headers: Dict[str, str], span: Optional[Span] = None) -> Dict[str, str]:
+    def inject_distributed_headers(
+        cls, request_headers: Dict[str, str], span: Optional[Span] = None, _soft_fail: bool = False
+    ) -> Dict[str, str]:
         """Injects the span's distributed context into the given request headers."""
         if cls.enabled is False:
             log.warning(
@@ -1813,53 +1818,74 @@ def inject_distributed_headers(cls, request_headers: Dict[str, str], span: Optional[Span] = None) -> Dict[str, str]:
         try:
             if not isinstance(request_headers, dict):
                 error = "invalid_request_headers"
-                log.warning("request_headers must be a dictionary of string key-value pairs.")
-                return request_headers
+                if _soft_fail:
+                    log.warning("request_headers must be a dictionary of string key-value pairs.")
+                    return request_headers
+                else:
+                    raise Exception("request_headers must be a dictionary of string key-value pairs.")
             if span is None:
                 span = cls._instance.tracer.current_span()
             if span is None:
                 error = "no_active_span"
-                log.warning("No span provided and no currently active span found.")
-                return request_headers
+                if _soft_fail:
+                    log.warning("No span provided and no currently active span found.")
+                    return request_headers
+                raise Exception("No span provided and no currently active span found.")
             if not isinstance(span, Span):
                 error = "invalid_span"
-                log.warning("span must be a valid Span object. Distributed context will not be injected.")
-                return request_headers
+                if _soft_fail:
+                    log.warning("span must be a valid Span object. Distributed context will not be injected.")
+                    return request_headers
+                raise Exception("span must be a valid Span object. Distributed context will not be injected.")
             HTTPPropagator.inject(span.context, request_headers)
             return request_headers
         finally:
             telemetry.record_inject_distributed_headers(error)

     @classmethod
-    def _activate_llmobs_distributed_context(cls, request_headers: Dict[str, str], context: Context) -> Optional[str]:
-        if cls.enabled is False:
-            return None
-        if not context.trace_id or not context.span_id:
-            log.warning("Failed to extract trace/span ID from request headers.")
-            return "missing_context"
-        _parent_id = context._meta.get(PROPAGATED_PARENT_ID_KEY)
-        if _parent_id is None:
-            log.debug("Failed to extract LLMObs parent ID from request headers.")
-            return "missing_parent_id"
+    def _activate_llmobs_distributed_context_soft_fail(cls, request_headers: Dict[str, str], context: Context) -> None:
+        cls._activate_llmobs_distributed_context(request_headers, context, _soft_fail=True)
+
+    @classmethod
+    def _activate_llmobs_distributed_context(
+        cls, request_headers: Dict[str, str], context: Context, _soft_fail: bool = False
+    ) -> None:
+        error = None
         try:
-            parent_id = int(_parent_id)
-        except ValueError:
-            log.warning("Failed to parse LLMObs parent ID from request headers.")
-            return "invalid_parent_id"
-        parent_llmobs_trace_id = context._meta.get(PROPAGATED_LLMOBS_TRACE_ID_KEY)
-        if parent_llmobs_trace_id is None:
-            log.debug("Failed to extract LLMObs trace ID from request headers. Expected string, got None.")
+            if cls.enabled is False:
+                return
+            if not context.trace_id or not context.span_id:
+                error = "missing_context"
+                if _soft_fail:
+                    log.warning("Failed to extract trace/span ID from request headers.")
+                    return
+                raise Exception("Failed to extract trace/span ID from request headers.")
+            _parent_id = context._meta.get(PROPAGATED_PARENT_ID_KEY)
+            if _parent_id is None:
+                error = "missing_parent_id"
+                log.debug("Failed to extract LLMObs parent ID from request headers.")
+                return
+            try:
+                parent_id = int(_parent_id)
+            except ValueError:
+                error = "invalid_parent_id"
+                log.warning("Failed to parse LLMObs parent ID from request headers.")
+                return
+            parent_llmobs_trace_id = context._meta.get(PROPAGATED_LLMOBS_TRACE_ID_KEY)
+            if parent_llmobs_trace_id is None:
+                log.debug("Failed to extract LLMObs trace ID from request headers. Expected string, got None.")
+                llmobs_context = Context(trace_id=context.trace_id, span_id=parent_id)
+                llmobs_context._meta[PROPAGATED_LLMOBS_TRACE_ID_KEY] = str(context.trace_id)
+                cls._instance._llmobs_context_provider.activate(llmobs_context)
+                error = "missing_parent_llmobs_trace_id"
             llmobs_context = Context(trace_id=context.trace_id, span_id=parent_id)
-            llmobs_context._meta[PROPAGATED_LLMOBS_TRACE_ID_KEY] = str(context.trace_id)
+            llmobs_context._meta[PROPAGATED_LLMOBS_TRACE_ID_KEY] = str(parent_llmobs_trace_id)
             cls._instance._llmobs_context_provider.activate(llmobs_context)
-            return "missing_parent_llmobs_trace_id"
-        llmobs_context = Context(trace_id=context.trace_id, span_id=parent_id)
-        llmobs_context._meta[PROPAGATED_LLMOBS_TRACE_ID_KEY] = str(parent_llmobs_trace_id)
-        cls._instance._llmobs_context_provider.activate(llmobs_context)
-        return None
+        finally:
+            telemetry.record_activate_distributed_headers(error)

     @classmethod
-    def activate_distributed_headers(cls, request_headers: Dict[str, str]) -> None:
+    def activate_distributed_headers(cls, request_headers: Dict[str, str], _soft_fail: bool = False) -> None:
         """
         Activates distributed tracing headers for the current request.
 
@@ -1873,8 +1899,7 @@ def activate_distributed_headers(cls, request_headers: Dict[str, str]) -> None:
             return
         context = HTTPPropagator.extract(request_headers)
         cls._instance.tracer.context_provider.activate(context)
-        error = cls._instance._activate_llmobs_distributed_context(request_headers, context)
-        telemetry.record_activate_distributed_headers(error)
+        cls._instance._activate_llmobs_distributed_context(request_headers, context, _soft_fail=_soft_fail)
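
[Review note] Net effect of the propagation changes: the public inject/activate APIs raise by default, while the HTTP integration hooks are re-pointed at _activate_llmobs_distributed_context_soft_fail, which keeps the old log-and-continue semantics via _soft_fail=True. A hypothetical server-side handler under the new API; the _soft_fail flag is shown exactly as it appears in this diff and, as a private-style parameter, may change:

    from ddtrace.llmobs import LLMObs

    def handle_request(headers):
        # Strict (new default): malformed or missing propagation headers raise.
        LLMObs.activate_distributed_headers(headers)

    def handle_request_lenient(headers):
        # Lenient, matching the pre-change behavior.
        LLMObs.activate_distributed_headers(headers, _soft_fail=True)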


# initialize the default llmobs instance