Skip to content

Commit 57fbb04

Browse files
committed
trace async guard
1 parent 27eb0b4 commit 57fbb04

File tree

2 files changed

+73
-60
lines changed

2 files changed

+73
-60
lines changed

guardrails/async_guard.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
set_tracer,
3838
set_tracer_context,
3939
)
40+
from guardrails.telemetry.hub_tracing import async_trace
4041
from guardrails.types.pydantic import ModelOrListOfModels
4142
from guardrails.types.validator import UseManyValidatorSpec, UseValidatorSpec
4243
from guardrails.telemetry import trace_async_guard_execution, wrap_with_otel_context
@@ -226,42 +227,6 @@ async def __exec(
226227
if full_schema_reask is None:
227228
full_schema_reask = self._base_model is not None
228229

229-
if self._allow_metrics_collection:
230-
llm_api_str = ""
231-
if llm_api:
232-
llm_api_module_name = (
233-
llm_api.__module__ if hasattr(llm_api, "__module__") else ""
234-
)
235-
llm_api_name = (
236-
llm_api.__name__
237-
if hasattr(llm_api, "__name__")
238-
else type(llm_api).__name__
239-
)
240-
llm_api_str = f"{llm_api_module_name}.{llm_api_name}"
241-
# Create a new span for this guard call
242-
self._hub_telemetry.create_new_span(
243-
span_name="/guard_call",
244-
attributes=[
245-
("guard_id", self.id),
246-
("user_id", self._user_id),
247-
("llm_api", llm_api_str),
248-
(
249-
"custom_reask_prompt",
250-
self._exec_opts.reask_prompt is not None,
251-
),
252-
(
253-
"custom_reask_instructions",
254-
self._exec_opts.reask_instructions is not None,
255-
),
256-
(
257-
"custom_reask_messages",
258-
self._exec_opts.reask_messages is not None,
259-
),
260-
],
261-
is_parent=True, # It will have children
262-
has_parent=False, # Has no parents
263-
)
264-
265230
set_call_kwargs(kwargs)
266231
set_tracer(self._tracer)
267232
set_tracer_context(self._tracer_context)
@@ -435,6 +400,7 @@ async def _exec(
435400
)
436401
return ValidationOutcome[OT].from_guard_history(call)
437402

403+
@async_trace(name="/guard_call", origin="AsyncGuard.__call__")
438404
async def __call__(
439405
self,
440406
llm_api: Optional[Callable[..., Awaitable[Any]]] = None,
@@ -501,6 +467,7 @@ async def __call__(
501467
**kwargs,
502468
)
503469

470+
@async_trace(name="/guard_call", origin="AsyncGuard.parse")
504471
async def parse(
505472
self,
506473
llm_output: str,
@@ -609,6 +576,7 @@ async def _stream_server_call(
609576
else:
610577
raise ValueError("AsyncGuard does not have an api client!")
611578

579+
@async_trace(name="/guard_call", origin="AsyncGuard.validate")
612580
async def validate(
613581
self, llm_output: str, *args, **kwargs
614582
) -> Awaitable[ValidationOutcome[OT]]:
Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import wraps
2-
from typing import Any, Callable, Dict, Optional
2+
from typing import Any, Awaitable, Callable, Dict, Optional
33

44
from opentelemetry.trace import Span
55

@@ -8,29 +8,36 @@
88
from guardrails.utils.hub_telemetry_utils import HubTelemetry
99

1010

11-
def get_guard_attributes(attrs: Dict[str, Any], guard_self: Any) -> Dict[str, Any]:
12-
attrs["guard_id"] = guard_self.id
13-
attrs["user_id"] = guard_self._user_id
14-
attrs["custom_reask_prompt"] = guard_self._exec_opts.reask_prompt is not None
15-
attrs["custom_reask_instructions"] = (
16-
guard_self._exec_opts.reask_instructions is not None
17-
)
18-
attrs["custom_reask_messages"] = guard_self._exec_opts.reask_messages is not None
19-
attrs["output_type"] = (
20-
"unstructured"
21-
if PrimitiveTypes.is_primitive(guard_self.output_schema.type.actual_instance)
22-
else "structured"
23-
)
24-
return attrs
25-
11+
def get_guard_call_attributes(
12+
attrs: Dict[str, Any], origin: str, *args, **kwargs
13+
) -> Dict[str, Any]:
14+
attrs["stream"] = kwargs.get("stream", False)
2615

27-
def get_guard_call_attributes(attrs: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
2816
guard_self = safe_get(args, 0)
2917
if guard_self is not None:
30-
attrs = get_guard_attributes(attrs, guard_self)
18+
attrs["guard_id"] = guard_self.id
19+
attrs["user_id"] = guard_self._user_id
20+
attrs["custom_reask_prompt"] = guard_self._exec_opts.reask_prompt is not None
21+
attrs["custom_reask_instructions"] = (
22+
guard_self._exec_opts.reask_instructions is not None
23+
)
24+
attrs["custom_reask_messages"] = (
25+
guard_self._exec_opts.reask_messages is not None
26+
)
27+
attrs["output_type"] = (
28+
"unstructured"
29+
if PrimitiveTypes.is_primitive(
30+
guard_self.output_schema.type.actual_instance
31+
)
32+
else "structured"
33+
)
34+
return attrs
3135

3236
llm_api_str = "" # noqa
33-
llm_api = safe_get(args, 1, kwargs.get("llm_api"))
37+
llm_api = kwargs.get("llm_api")
38+
if origin in ["Guard.__call__", "AsyncGuard.__call__"]:
39+
llm_api = safe_get(args, 1, llm_api)
40+
3441
if llm_api:
3542
llm_api_module_name = (
3643
llm_api.__module__ if hasattr(llm_api, "__module__") else ""
@@ -44,16 +51,24 @@ def get_guard_call_attributes(attrs: Dict[str, Any], *args, **kwargs) -> Dict[st
4451
return attrs
4552

4653

47-
def add_attributes(name: str, span: Span, origin: str, *args, **kwargs):
48-
attrs = {"origin": origin}
49-
if origin == "Guard.__call__":
54+
def add_attributes(
55+
span: Span, attrs: Dict[str, Any], name: str, origin: str, *args, **kwargs
56+
):
57+
attrs["origin"] = origin
58+
if name == "/guard_call":
5059
attrs = get_guard_call_attributes(attrs, *args, **kwargs)
5160

5261
for key, value in attrs.items():
5362
span.set_attribute(key, value)
5463

5564

56-
def trace(*, name: str, origin: str, is_parent: Optional[bool] = False):
65+
def trace(
66+
*,
67+
name: str,
68+
origin: str,
69+
is_parent: Optional[bool] = False,
70+
**attrs,
71+
):
5772
def decorator(fn: Callable[..., Any]):
5873
@wraps(fn)
5974
def wrapper(*args, **kwargs):
@@ -69,11 +84,41 @@ def wrapper(*args, **kwargs):
6984
# Inject the current context
7085
hub_telemetry.inject_current_context()
7186

72-
add_attributes(name, span, origin, *args, **kwargs)
87+
add_attributes(span, attrs, origin, *args, **kwargs)
7388
return fn(*args, **kwargs)
7489
else:
7590
return fn(*args, **kwargs)
7691

7792
return wrapper
7893

7994
return decorator
95+
96+
97+
def async_trace(
98+
*,
99+
name: str,
100+
origin: str,
101+
is_parent: Optional[bool] = False,
102+
):
103+
def decorator(fn: Callable[..., Awaitable[Any]]):
104+
@wraps(fn)
105+
async def async_wrapper(*args, **kwargs):
106+
hub_telemetry = HubTelemetry()
107+
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
108+
context = (
109+
hub_telemetry.extract_current_context() if not is_parent else None
110+
)
111+
with hub_telemetry._tracer.start_as_current_span(
112+
name, context=context
113+
) as span: # noqa
114+
if is_parent:
115+
# Inject the current context
116+
hub_telemetry.inject_current_context()
117+
add_attributes(span, {"async": True}, origin, *args, **kwargs)
118+
return await fn(*args, **kwargs)
119+
else:
120+
return await fn(*args, **kwargs)
121+
122+
return async_wrapper
123+
124+
return decorator

0 commit comments

Comments
 (0)