Skip to content

Commit 56882fd

Browse files
committed
don't type trace wrappers
1 parent 1ba15f2 commit 56882fd

File tree

4 files changed

+21
-16
lines changed

4 files changed

+21
-16
lines changed

guardrails/classes/llm/llm_response.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def to_dict(self) -> Dict[str, Any]:
7373
def from_interface(cls, i_llm_response: ILLMResponse) -> "LLMResponse":
7474
stream_output = None
7575
if i_llm_response.stream_output:
76-
stream_output = [so for so in i_llm_response.stream_output]
76+
stream_output = iter([so for so in i_llm_response.stream_output])
7777

7878
async_stream_output = None
7979
if i_llm_response.async_stream_output:

guardrails/hub_telemetry/hub_tracing.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
from functools import wraps
22
from typing import (
33
Any,
4-
AsyncIterator,
5-
Awaitable,
6-
Callable,
74
Dict,
8-
Iterator,
95
Optional,
106
TypeVar,
117
)
@@ -136,9 +132,11 @@ def trace(
136132
is_parent: Optional[bool] = False,
137133
**attrs,
138134
):
139-
def decorator(fn: Callable[..., R]):
135+
# def decorator(fn: Callable[..., R]):
136+
def decorator(fn):
140137
@wraps(fn)
141-
def wrapper(*args, **kwargs) -> R:
138+
# def wrapper(*args, **kwargs) -> R:
139+
def wrapper(*args, **kwargs):
142140
hub_telemetry = HubTelemetry()
143141
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
144142
context = (
@@ -174,9 +172,11 @@ def async_trace(
174172
origin: Optional[str] = None,
175173
is_parent: Optional[bool] = False,
176174
):
177-
def decorator(fn: Callable[..., Awaitable[R]]):
175+
# def decorator(fn: Callable[..., Awaitable[R]]):
176+
def decorator(fn):
178177
@wraps(fn)
179-
async def async_wrapper(*args, **kwargs) -> R:
178+
# async def async_wrapper(*args, **kwargs) -> R:
179+
async def async_wrapper(*args, **kwargs):
180180
hub_telemetry = HubTelemetry()
181181
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
182182
context = (
@@ -214,9 +214,11 @@ def trace_stream(
214214
is_parent: Optional[bool] = False,
215215
**attrs,
216216
):
217-
def decorator(fn: Callable[..., Iterator[R]]):
217+
# def decorator(fn: Callable[..., Iterator[R]]):
218+
def decorator(fn):
218219
@wraps(fn)
219-
def wrapper(*args, **kwargs) -> Iterator[R]:
220+
# def wrapper(*args, **kwargs) -> Iterator[R]:
221+
def wrapper(*args, **kwargs):
220222
hub_telemetry = HubTelemetry()
221223
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
222224
context = (
@@ -256,9 +258,10 @@ def async_trace_stream(
256258
is_parent: Optional[bool] = False,
257259
**attrs,
258260
):
259-
def decorator(fn: Callable[..., AsyncIterator[R]]):
261+
# def decorator(fn: Callable[..., AsyncIterator[R]]):
262+
def decorator(fn):
260263
@wraps(fn)
261-
async def wrapper(*args, **kwargs) -> AsyncIterator[R]:
264+
async def wrapper(*args, **kwargs):
262265
hub_telemetry = HubTelemetry()
263266
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
264267
context = (

guardrails/integrations/databricks/ml_flow_instrumentor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Awaitable,
88
Callable,
99
Coroutine,
10-
Generator,
1110
Iterator,
1211
Union,
1312
)
@@ -220,12 +219,12 @@ def trace_step_wrapper(*args, **kwargs) -> Iteration:
220219
return trace_step_wrapper
221220

222221
def _instrument_stream_runner_step(
223-
self, runner_step: Callable[..., Generator[ValidationOutcome[OT], None, None]]
222+
self, runner_step: Callable[..., Iterator[ValidationOutcome[OT]]]
224223
):
225224
@wraps(runner_step)
226225
def trace_stream_step_wrapper(
227226
*args, **kwargs
228-
) -> Generator[ValidationOutcome[OT], None, None]:
227+
) -> Iterator[ValidationOutcome[OT]]:
229228
with mlflow.start_span(
230229
name="guardrails/guard/step",
231230
span_type="step",

guardrails/validator_service/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
)
1313
from guardrails.types import ValidatorMap
1414
from guardrails.telemetry.legacy_validator_tracing import trace_validation_result
15+
16+
# Keep this imported for backwards compatibility
17+
from guardrails.validator_service.validator_service_base import ValidatorServiceBase # noqa
1518
from guardrails.validator_service.async_validator_service import AsyncValidatorService
1619
from guardrails.validator_service.sequential_validator_service import (
1720
SequentialValidatorService,

0 commit comments

Comments
 (0)