diff --git a/guardrails/telemetry/common.py b/guardrails/telemetry/common.py index 39892d060..d2cc53ca8 100644 --- a/guardrails/telemetry/common.py +++ b/guardrails/telemetry/common.py @@ -1,5 +1,6 @@ import json from typing import Any, Callable, Dict, Optional, Union +from opentelemetry.baggage import get_baggage from opentelemetry import context from opentelemetry.context import Context from opentelemetry.trace import Tracer, Span @@ -103,3 +104,23 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any: context.detach(token) return wrapped_func + + +def add_user_attributes(span: Span): + try: + client_ip = get_baggage("client.ip") or "unknown" + user_agent = get_baggage("http.user_agent") or "unknown" + referrer = get_baggage("http.referrer") or "unknown" + user_id = get_baggage("user.id") or "unknown" + organization = get_baggage("organization") or "unknown" + app = get_baggage("app") or "unknown" + + span.set_attribute("client.ip", str(client_ip)) + span.set_attribute("http.user_agent", str(user_agent)) + span.set_attribute("http.referrer", str(referrer)) + span.set_attribute("user.id", str(user_id)) + span.set_attribute("organization", str(organization)) + span.set_attribute("app", str(app)) + except Exception as e: + logger.warning("Error loading baggage user information", e) + pass diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index 535f2db4a..57e5a4614 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -19,6 +19,7 @@ from guardrails.classes.output_type import OT from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.telemetry.open_inference import trace_operation +from guardrails.telemetry.common import add_user_attributes from guardrails.version import GUARDRAILS_VERSION @@ -141,6 +142,7 @@ def trace_stream_guard( # FIXME: This should only be called once; # Accumulate the validated output and call at the end add_guard_attributes(guard_span, history, res) + add_user_attributes(guard_span) yield res except StopIteration: next_exists = False @@ -175,6 +177,7 @@ def trace_guard_execution( ): return trace_stream_guard(guard_span, result, history) add_guard_attributes(guard_span, history, result) + add_user_attributes(guard_span) return result except Exception as e: guard_span.set_status(status=StatusCode.ERROR, description=str(e)) @@ -193,6 +196,7 @@ async def trace_async_stream_guard( try: res = await anext(result) # type: ignore add_guard_attributes(guard_span, history, res) + add_user_attributes(guard_span) yield res except StopIteration: next_exists = False @@ -244,9 +248,11 @@ async def trace_async_guard_execution( if inspect.isawaitable(result): res = await result add_guard_attributes(guard_span, history, res) # type: ignore + add_user_attributes(guard_span) return res except Exception as e: guard_span.set_status(status=StatusCode.ERROR, description=str(e)) + add_user_attributes(guard_span) raise e else: return await _execute_fn(*args, **kwargs) diff --git a/guardrails/telemetry/open_inference.py b/guardrails/telemetry/open_inference.py index 266c36c2c..7c58d8a84 100644 --- a/guardrails/telemetry/open_inference.py +++ b/guardrails/telemetry/open_inference.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional -from guardrails.telemetry.common import get_span, to_dict -from guardrails.utils.serialization_utils import serialize +from guardrails.telemetry.common import get_span, to_dict, serialize def trace_operation( diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py index 1b3bb346f..9226f9fb5 100644 --- a/guardrails/telemetry/runner_tracing.py +++ b/guardrails/telemetry/runner_tracing.py @@ -17,9 +17,8 @@ from guardrails.classes.output_type import OT from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.stores.context import get_guard_name -from guardrails.telemetry.common import get_tracer +from guardrails.telemetry.common import get_tracer, add_user_attributes, serialize from guardrails.utils.safe_get import safe_get -from guardrails.utils.serialization_utils import serialize from guardrails.version import GUARDRAILS_VERSION @@ -73,10 +72,12 @@ def trace_step_wrapper(*args, **kwargs) -> Iteration: try: response = fn(*args, **kwargs) add_step_attributes(step_span, response, *args, **kwargs) + add_user_attributes(step_span) return response except Exception as e: step_span.set_status(status=StatusCode.ERROR, description=str(e)) add_step_attributes(step_span, None, *args, **kwargs) + add_user_attributes(step_span) raise e else: return fn(*args, **kwargs) @@ -112,6 +113,7 @@ def trace_stream_step_generator( call = safe_get(args, 8, kwargs.get("call_log", None)) iteration = call.iterations.last if call else None add_step_attributes(step_span, iteration, *args, **kwargs) + add_user_attributes(step_span) if exception: raise exception @@ -145,10 +147,12 @@ async def trace_async_step_wrapper(*args, **kwargs) -> Iteration: ) as step_span: try: response = await fn(*args, **kwargs) + add_user_attributes(step_span) add_step_attributes(step_span, response, *args, **kwargs) return response except Exception as e: step_span.set_status(status=StatusCode.ERROR, description=str(e)) + add_user_attributes(step_span) add_step_attributes(step_span, None, *args, **kwargs) raise e diff --git a/guardrails/telemetry/validator_tracing.py b/guardrails/telemetry/validator_tracing.py index 4f88cd23e..e9114b7be 100644 --- a/guardrails/telemetry/validator_tracing.py +++ b/guardrails/telemetry/validator_tracing.py @@ -13,11 +13,10 @@ from guardrails.settings import settings from guardrails.classes.validation.validation_result import ValidationResult -from guardrails.telemetry.common import get_tracer +from guardrails.telemetry.common import get_tracer, add_user_attributes, serialize from guardrails.telemetry.open_inference import trace_operation from guardrails.utils.casting_utils import to_string from guardrails.utils.safe_get import safe_get -from guardrails.utils.serialization_utils import serialize from guardrails.version import GUARDRAILS_VERSION @@ -106,6 +105,7 @@ def trace_validator_wrapper(*args, **kwargs): ) as validator_span: try: resp = fn(*args, **kwargs) + add_user_attributes(validator_span) add_validator_attributes( *args, validator_span=validator_span, @@ -122,6 +122,7 @@ def trace_validator_wrapper(*args, **kwargs): validator_span.set_status( status=StatusCode.ERROR, description=str(e) ) + add_user_attributes(validator_span) add_validator_attributes( *args, validator_span=validator_span, @@ -168,6 +169,7 @@ async def trace_validator_wrapper(*args, **kwargs): ) as validator_span: try: resp = await fn(*args, **kwargs) + add_user_attributes(validator_span) add_validator_attributes( *args, validator_span=validator_span, @@ -184,6 +186,7 @@ async def trace_validator_wrapper(*args, **kwargs): validator_span.set_status( status=StatusCode.ERROR, description=str(e) ) + add_user_attributes(validator_span) add_validator_attributes( *args, validator_span=validator_span,