Skip to content

Commit 1ba15f2

Browse files
committed
tests, lint, some type
1 parent ec6a564 commit 1ba15f2

File tree

19 files changed

+160
-132
lines changed

19 files changed

+160
-132
lines changed

guardrails/api_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from typing import Any, Iterable, Optional
3+
from typing import Any, Iterator, Optional
44

55
import requests
66
from guardrails_api_client.configuration import Configuration
@@ -80,7 +80,7 @@ def stream_validate(
8080
guard: Guard,
8181
payload: ValidatePayload,
8282
openai_api_key: Optional[str] = None,
83-
) -> Iterable[Any]:
83+
) -> Iterator[Any]:
8484
_openai_api_key = (
8585
openai_api_key
8686
if openai_api_key is not None

guardrails/async_guard.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from opentelemetry import context as otel_context
55
from typing import (
66
Any,
7-
AsyncIterable,
7+
AsyncIterator,
88
Awaitable,
99
Callable,
1010
Dict,
@@ -188,7 +188,7 @@ async def _execute(
188188
) -> Union[
189189
ValidationOutcome[OT],
190190
Awaitable[ValidationOutcome[OT]],
191-
AsyncIterable[ValidationOutcome[OT]],
191+
AsyncIterator[ValidationOutcome[OT]],
192192
]:
193193
self._fill_validator_map()
194194
self._fill_validators()
@@ -220,7 +220,7 @@ async def __exec(
220220
) -> Union[
221221
ValidationOutcome[OT],
222222
Awaitable[ValidationOutcome[OT]],
223-
AsyncIterable[ValidationOutcome[OT]],
223+
AsyncIterator[ValidationOutcome[OT]],
224224
]:
225225
prompt_params = prompt_params or {}
226226
metadata = metadata or {}
@@ -334,7 +334,7 @@ async def _exec(
334334
) -> Union[
335335
ValidationOutcome[OT],
336336
Awaitable[ValidationOutcome[OT]],
337-
AsyncIterable[ValidationOutcome[OT]],
337+
AsyncIterator[ValidationOutcome[OT]],
338338
]:
339339
"""Call the LLM asynchronously and validate the output.
340340
@@ -416,7 +416,7 @@ async def __call__(
416416
) -> Union[
417417
ValidationOutcome[OT],
418418
Awaitable[ValidationOutcome[OT]],
419-
AsyncIterable[ValidationOutcome[OT]],
419+
AsyncIterator[ValidationOutcome[OT]],
420420
]:
421421
"""Call the LLM and validate the output. Pass an async LLM API to
422422
return a coroutine.
@@ -534,7 +534,7 @@ async def parse(
534534

535535
async def _stream_server_call(
536536
self, *, payload: Dict[str, Any]
537-
) -> AsyncIterable[ValidationOutcome[OT]]:
537+
) -> AsyncIterator[ValidationOutcome[OT]]:
538538
# TODO: Once server side supports async streaming, this function will need to
539539
# yield async generators, not generators
540540
if self._api_client:

guardrails/classes/llm/llm_response.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from itertools import tee
3-
from typing import Any, Dict, Iterable, Optional, AsyncIterable
3+
from typing import Any, Dict, Iterator, Optional, AsyncIterator
44

55
from guardrails_api_client import LLMResponse as ILLMResponse
66
from pydantic.config import ConfigDict
@@ -19,9 +19,9 @@ class LLMResponse(ILLMResponse):
1919
2020
Attributes:
2121
output (str): The output from the LLM.
22-
stream_output (Optional[Iterable]): A stream of output from the LLM.
22+
stream_output (Optional[Iterator]): A stream of output from the LLM.
2323
Default None.
24-
async_stream_output (Optional[AsyncIterable]): An async stream of output
24+
async_stream_output (Optional[AsyncIterator]): An async stream of output
2525
from the LLM. Default None.
2626
prompt_token_count (Optional[int]): The number of tokens in the prompt.
2727
Default None.
@@ -35,8 +35,8 @@ class LLMResponse(ILLMResponse):
3535
prompt_token_count: Optional[int] = None
3636
response_token_count: Optional[int] = None
3737
output: str
38-
stream_output: Optional[Iterable] = None
39-
async_stream_output: Optional[AsyncIterable] = None
38+
stream_output: Optional[Iterator] = None
39+
async_stream_output: Optional[AsyncIterator] = None
4040

4141
def to_interface(self) -> ILLMResponse:
4242
stream_output = None

guardrails/cli/telemetry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
def trace_if_enabled(command_name: str):
88
if settings.rc.enable_metrics is True:
9-
telemetry = HubTelemetry(enabled=True)
9+
telemetry = HubTelemetry()
10+
telemetry._enabled = True
1011
telemetry.create_new_span(
1112
f"guardrails-cli/{command_name}",
1213
[

guardrails/guard.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Callable,
88
Dict,
99
Generic,
10-
Iterable,
10+
Iterator,
1111
List,
1212
Optional,
1313
Sequence,
@@ -302,11 +302,13 @@ def _configure_hub_telemtry(
302302

303303
self._allow_metrics_collection = allow_metrics_collection
304304

305+
# Initialize Hub Telemetry singleton and get the tracer
306+
self._hub_telemetry = HubTelemetry()
307+
self._hub_telemetry._enabled = allow_metrics_collection
308+
305309
if allow_metrics_collection is True:
306310
# Get unique id of user from rc file
307311
self._user_id = settings.rc.id or ""
308-
# Initialize Hub Telemetry singleton and get the tracer
309-
self._hub_telemetry = HubTelemetry(enabled=True)
310312

311313
def _fill_validator_map(self):
312314
# dont init validators if were going to call the server
@@ -696,7 +698,7 @@ def _execute(
696698
metadata: Optional[Dict],
697699
full_schema_reask: Optional[bool] = None,
698700
**kwargs,
699-
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
701+
) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]:
700702
self._fill_validator_map()
701703
self._fill_validators()
702704
self._fill_exec_opts(
@@ -837,7 +839,7 @@ def _exec(
837839
instructions: Optional[str] = None,
838840
msg_history: Optional[List[Dict]] = None,
839841
**kwargs,
840-
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
842+
) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]:
841843
api = None
842844

843845
if llm_api is not None or kwargs.get("model") is not None:
@@ -901,7 +903,7 @@ def __call__(
901903
metadata: Optional[Dict] = None,
902904
full_schema_reask: Optional[bool] = None,
903905
**kwargs,
904-
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
906+
) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]:
905907
"""Call the LLM and validate the output.
906908
907909
Args:
@@ -1187,7 +1189,7 @@ def _stream_server_call(
11871189
self,
11881190
*,
11891191
payload: Dict[str, Any],
1190-
) -> Iterable[ValidationOutcome[OT]]:
1192+
) -> Iterator[ValidationOutcome[OT]]:
11911193
if settings.use_server and self._api_client:
11921194
validation_output: Optional[IValidationOutcome] = None
11931195
response = self._api_client.stream_validate(
@@ -1237,7 +1239,7 @@ def _call_server(
12371239
metadata: Optional[Dict] = {},
12381240
full_schema_reask: Optional[bool] = True,
12391241
**kwargs,
1240-
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
1242+
) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]:
12411243
if settings.use_server and self._api_client:
12421244
payload: Dict[str, Any] = {
12431245
"args": list(args),

guardrails/hub_telemetry/hub_tracing.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from functools import wraps
2-
from typing import Any, Awaitable, Callable, Dict, Optional
2+
from typing import (
3+
Any,
4+
AsyncIterator,
5+
Awaitable,
6+
Callable,
7+
Dict,
8+
Iterator,
9+
Optional,
10+
TypeVar,
11+
)
312

413
from opentelemetry.trace import Span
514

@@ -9,6 +18,8 @@
918
from guardrails.utils.safe_get import safe_get
1019
from guardrails.utils.hub_telemetry_utils import HubTelemetry
1120

21+
R = TypeVar("R", covariant=True)
22+
1223

1324
def get_guard_call_attributes(
1425
attrs: Dict[str, Any], origin: str, *args, **kwargs
@@ -76,7 +87,9 @@ def get_validator_inference_attributes(
7687
def get_validator_usage_attributes(
7788
attrs: Dict[str, Any], response, *args, **kwargs
7889
) -> Dict[str, Any]:
79-
validator_self = safe_get(args, 0)
90+
# We're wrapping a wrapped function,
91+
# so the first arg is the validator service
92+
validator_self = safe_get(args, 1)
8093
if validator_self is not None:
8194
attrs["validator_name"] = validator_self.rail_alias
8295
attrs["validator_on_fail"] = validator_self.on_fail_descriptor
@@ -90,11 +103,17 @@ def get_validator_usage_attributes(
90103

91104

92105
def add_attributes(
93-
span: Span, attrs: Dict[str, Any], name: str, origin: str, response, *args, **kwargs
106+
span: Span,
107+
attrs: Dict[str, Any],
108+
name: str,
109+
origin: str,
110+
*args,
111+
response=None,
112+
**kwargs,
94113
):
95114
attrs["origin"] = origin
96115
if name == "/guard_call":
97-
attrs = get_guard_call_attributes(attrs, *args, **kwargs)
116+
attrs = get_guard_call_attributes(attrs, origin, *args, **kwargs)
98117
elif name == "/reasks":
99118
if response is not None and hasattr(response, "iterations"):
100119
attrs["reask_count"] = len(response.iterations) - 1
@@ -103,7 +122,7 @@ def add_attributes(
103122
elif name == "/validator_inference":
104123
attrs = get_validator_inference_attributes(attrs, *args, **kwargs)
105124
elif name == "/validator_usage":
106-
attrs = get_validator_usage_attributes(attrs, response * args, **kwargs)
125+
attrs = get_validator_usage_attributes(attrs, response, *args, **kwargs)
107126

108127
for key, value in attrs.items():
109128
if value is not None:
@@ -117,9 +136,9 @@ def trace(
117136
is_parent: Optional[bool] = False,
118137
**attrs,
119138
):
120-
def decorator(fn: Callable[..., Any]):
139+
def decorator(fn: Callable[..., R]):
121140
@wraps(fn)
122-
def wrapper(*args, **kwargs):
141+
def wrapper(*args, **kwargs) -> R:
123142
hub_telemetry = HubTelemetry()
124143
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
125144
context = (
@@ -137,7 +156,9 @@ def wrapper(*args, **kwargs):
137156
origin = origin if origin is not None else name
138157

139158
resp = fn(*args, **kwargs)
140-
add_attributes(span, attrs, origin, resp, *args, **kwargs)
159+
add_attributes(
160+
span, attrs, name, origin, *args, response=resp, **kwargs
161+
)
141162
return resp
142163
else:
143164
return fn(*args, **kwargs)
@@ -153,9 +174,9 @@ def async_trace(
153174
origin: Optional[str] = None,
154175
is_parent: Optional[bool] = False,
155176
):
156-
def decorator(fn: Callable[..., Awaitable[Any]]):
177+
def decorator(fn: Callable[..., Awaitable[R]]):
157178
@wraps(fn)
158-
async def async_wrapper(*args, **kwargs):
179+
async def async_wrapper(*args, **kwargs) -> R:
159180
hub_telemetry = HubTelemetry()
160181
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
161182
context = (
@@ -170,7 +191,7 @@ async def async_wrapper(*args, **kwargs):
170191

171192
nonlocal origin
172193
origin = origin if origin is not None else name
173-
add_attributes(span, {"async": True}, origin, *args, **kwargs)
194+
add_attributes(span, {"async": True}, name, origin, *args, **kwargs)
174195
return await fn(*args, **kwargs)
175196
else:
176197
return await fn(*args, **kwargs)
@@ -193,9 +214,9 @@ def trace_stream(
193214
is_parent: Optional[bool] = False,
194215
**attrs,
195216
):
196-
def decorator(fn: Callable[..., Any]):
217+
def decorator(fn: Callable[..., Iterator[R]]):
197218
@wraps(fn)
198-
def wrapper(*args, **kwargs):
219+
def wrapper(*args, **kwargs) -> Iterator[R]:
199220
hub_telemetry = HubTelemetry()
200221
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
201222
context = (
@@ -212,7 +233,7 @@ def wrapper(*args, **kwargs):
212233

213234
nonlocal origin
214235
origin = origin if origin is not None else name
215-
add_attributes(span, attrs, name, origin, None, *args, **kwargs)
236+
add_attributes(span, attrs, name, origin, *args, **kwargs)
216237
return _run_gen(fn, *args, **kwargs)
217238
else:
218239
return fn(*args, **kwargs)
@@ -235,9 +256,9 @@ def async_trace_stream(
235256
is_parent: Optional[bool] = False,
236257
**attrs,
237258
):
238-
def decorator(fn: Callable[..., Awaitable[Any]]):
259+
def decorator(fn: Callable[..., AsyncIterator[R]]):
239260
@wraps(fn)
240-
async def wrapper(*args, **kwargs):
261+
async def wrapper(*args, **kwargs) -> AsyncIterator[R]:
241262
hub_telemetry = HubTelemetry()
242263
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
243264
context = (
@@ -254,7 +275,7 @@ async def wrapper(*args, **kwargs):
254275

255276
nonlocal origin
256277
origin = origin if origin is not None else name
257-
add_attributes(span, attrs, name, origin, None, *args, **kwargs)
278+
add_attributes(span, attrs, name, origin, *args, **kwargs)
258279
return _run_async_gen(fn, *args, **kwargs)
259280
else:
260281
return fn(*args, **kwargs)

0 commit comments

Comments
 (0)