Skip to content

Commit fdaa0ed

Browse files
committed
fix context issues
1 parent d3faf51 commit fdaa0ed

File tree

10 files changed

+38
-67
lines changed

10 files changed

+38
-67
lines changed

guardrails/cli/create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
console = Console()
1616

1717

18-
@trace(name="guardrails-cli/create", is_parent=True)
1918
@gr_cli.command(name="create")
19+
@trace(name="guardrails-cli/create")
2020
def create_command(
2121
validators: Optional[str] = typer.Option(
2222
default="",

guardrails/cli/hub/create_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def test_failure_case(self):
147147
)
148148

149149

150-
@trace(name="guardrails-cli/hub/create-validator", is_parent=True)
151150
@hub_command.command(name="create-validator")
151+
@trace(name="guardrails-cli/hub/create-validator")
152152
def create_validator(
153153
name: str = typer.Argument(help="The name for your validator."),
154154
filepath: str = typer.Argument(

guardrails/cli/hub/install.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from guardrails.cli.version import version_warnings_if_applicable
1111

1212

13-
@trace(name="guardrails-cli/hub/install", is_parent=True)
1413
@hub_command.command()
14+
@trace(name="guardrails-cli/hub/install")
1515
def install(
1616
package_uris: List[str] = typer.Argument(
1717
...,

guardrails/cli/hub/list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from .console import console
88

99

10-
@trace(name="guardrails-cli/hub/list", is_parent=True)
1110
@hub_command.command(name="list")
11+
@trace(name="guardrails-cli/hub/list")
1212
def list():
1313
"""List all installed validators."""
1414
site_packages = get_site_packages_location()

guardrails/cli/hub/submit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from guardrails.hub_telemetry.hub_tracing import trace
1212

1313

14-
@trace(name="guardrails-cli/hub/submit", is_parent=True)
1514
@hub_command.command(name="submit")
15+
@trace(name="guardrails-cli/hub/submit")
1616
def submit(
1717
package_name: str = typer.Argument(help="The package name for your validator."),
1818
filepath: str = typer.Argument(

guardrails/cli/hub/uninstall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def uninstall_hub_module(manifest: Manifest, site_packages: str):
7474
sys.exit(1)
7575

7676

77-
@trace(name="guardrails-cli/hub/uninstall", is_parent=True)
7877
@hub_command.command()
78+
@trace(name="guardrails-cli/hub/uninstall")
7979
def uninstall(
8080
package_uri: str = typer.Argument(
8181
help="URI to the package to uninstall. Example: hub://guardrails/regex_match."

guardrails/cli/telemetry.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,4 @@ def trace_if_enabled(command_name: str):
1919
("machine", platform.machine()),
2020
("processor", platform.processor()),
2121
],
22-
True,
23-
False,
2422
)

guardrails/cli/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def validate_llm_output(rail: str, llm_output: str) -> Union[str, Dict, List, No
1515
return result.validated_output
1616

1717

18-
@trace(name="guardrails-cli/validate", is_parent=True)
1918
@guardrails.command()
19+
@trace(name="guardrails-cli/validate")
2020
def validate(
2121
rail: str = typer.Argument(
2222
..., help="Path to the rail spec.", exists=True, file_okay=True, dir_okay=False

guardrails/hub_telemetry/hub_tracing.py

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88

99
from opentelemetry.trace import Span
10+
from opentelemetry.trace.propagation import set_span_in_context
1011

1112
from guardrails.classes.validation.validation_result import ValidationResult
1213
from guardrails.hub_token.token import VALIDATOR_HUB_SERVICE
@@ -129,27 +130,20 @@ def trace(
129130
*,
130131
name: str,
131132
origin: Optional[str] = None,
132-
is_parent: Optional[bool] = False,
133133
**attrs,
134134
):
135-
# def decorator(fn: Callable[..., R]):
136135
def decorator(fn):
137136
@wraps(fn)
138-
# def wrapper(*args, **kwargs) -> R:
139137
def wrapper(*args, **kwargs):
140138
hub_telemetry = HubTelemetry()
141139
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
142-
context = (
143-
hub_telemetry.extract_current_context() if not is_parent else None
144-
)
145140
with hub_telemetry._tracer.start_span(
146141
name,
147-
context=context,
142+
context=hub_telemetry.extract_current_context(),
148143
set_status_on_exception=True,
149144
) as span: # noqa
150-
if is_parent:
151-
# Inject the current context
152-
hub_telemetry.inject_current_context()
145+
context = set_span_in_context(span)
146+
hub_telemetry.inject_current_context(context=context)
153147
nonlocal origin
154148
origin = origin if origin is not None else name
155149

@@ -170,24 +164,19 @@ def async_trace(
170164
*,
171165
name: str,
172166
origin: Optional[str] = None,
173-
is_parent: Optional[bool] = False,
174167
):
175-
# def decorator(fn: Callable[..., Awaitable[R]]):
176168
def decorator(fn):
177169
@wraps(fn)
178-
# async def async_wrapper(*args, **kwargs) -> R:
179170
async def async_wrapper(*args, **kwargs):
180171
hub_telemetry = HubTelemetry()
181172
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
182-
context = (
183-
hub_telemetry.extract_current_context() if not is_parent else None
184-
)
185-
with hub_telemetry._tracer.start_as_current_span(
186-
name, context=context
173+
with hub_telemetry._tracer.start_span(
174+
name,
175+
context=hub_telemetry.extract_current_context(),
176+
set_status_on_exception=True,
187177
) as span: # noqa
188-
if is_parent:
189-
# Inject the current context
190-
hub_telemetry.inject_current_context()
178+
context = set_span_in_context(span)
179+
hub_telemetry.inject_current_context(context=context)
191180

192181
nonlocal origin
193182
origin = origin if origin is not None else name
@@ -211,27 +200,20 @@ def trace_stream(
211200
*,
212201
name: str,
213202
origin: Optional[str] = None,
214-
is_parent: Optional[bool] = False,
215203
**attrs,
216204
):
217-
# def decorator(fn: Callable[..., Iterator[R]]):
218205
def decorator(fn):
219206
@wraps(fn)
220-
# def wrapper(*args, **kwargs) -> Iterator[R]:
221207
def wrapper(*args, **kwargs):
222208
hub_telemetry = HubTelemetry()
223209
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
224-
context = (
225-
hub_telemetry.extract_current_context() if not is_parent else None
226-
)
227210
with hub_telemetry._tracer.start_span(
228211
name,
229-
context=context,
212+
context=hub_telemetry.extract_current_context(),
230213
set_status_on_exception=True,
231214
) as span: # noqa
232-
if is_parent:
233-
# Inject the current context
234-
hub_telemetry.inject_current_context()
215+
context = set_span_in_context(span)
216+
hub_telemetry.inject_current_context(context=context)
235217

236218
nonlocal origin
237219
origin = origin if origin is not None else name
@@ -255,26 +237,20 @@ def async_trace_stream(
255237
*,
256238
name: str,
257239
origin: Optional[str] = None,
258-
is_parent: Optional[bool] = False,
259240
**attrs,
260241
):
261-
# def decorator(fn: Callable[..., AsyncIterator[R]]):
262242
def decorator(fn):
263243
@wraps(fn)
264244
async def wrapper(*args, **kwargs):
265245
hub_telemetry = HubTelemetry()
266246
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
267-
context = (
268-
hub_telemetry.extract_current_context() if not is_parent else None
269-
)
270247
with hub_telemetry._tracer.start_span(
271248
name,
272-
context=context,
249+
context=hub_telemetry.extract_current_context(),
273250
set_status_on_exception=True,
274251
) as span: # noqa
275-
if is_parent:
276-
# Inject the current context
277-
hub_telemetry.inject_current_context()
252+
context = set_span_in_context(span)
253+
hub_telemetry.inject_current_context(context=context)
278254

279255
nonlocal origin
280256
origin = origin if origin is not None else name

guardrails/utils/hub_telemetry_utils.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from opentelemetry.sdk.trace import TracerProvider
1010
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
1111
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
12+
from opentelemetry.trace.propagation import set_span_in_context
1213

1314

1415
class HubTelemetry:
@@ -57,10 +58,10 @@ def initialize_tracer(
5758
self._enabled = enabled
5859
self._carrier = {}
5960
self._service_name = service_name
60-
# self._endpoint = "http://localhost:4318/v1/traces"
61-
self._endpoint = (
62-
"https://hty0gc1ok3.execute-api.us-east-1.amazonaws.com/v1/traces"
63-
)
61+
self._endpoint = "http://localhost:5318/v1/traces"
62+
# self._endpoint = (
63+
# "https://hty0gc1ok3.execute-api.us-east-1.amazonaws.com/v1/traces"
64+
# )
6465
self._tracer_name = tracer_name
6566

6667
# Create a resource
@@ -85,11 +86,14 @@ def initialize_tracer(
8586

8687
self._prop = TraceContextTextMapPropagator()
8788

88-
def inject_current_context(self) -> None:
89+
def inject_current_context(self, context=None) -> None:
8990
"""Injects the current context into the carrier."""
9091
if not self._prop:
9192
return
92-
self._prop.inject(carrier=self._carrier)
93+
if context is not None:
94+
self._prop.inject(carrier=self._carrier, context=context)
95+
else:
96+
self._prop.inject(carrier=self._carrier)
9397

9498
def extract_current_context(self):
9599
"""Extracts the current context from the carrier."""
@@ -98,13 +102,7 @@ def extract_current_context(self):
98102
context = self._prop.extract(carrier=self._carrier)
99103
return context
100104

101-
def create_new_span(
102-
self,
103-
span_name: str,
104-
attributes: list,
105-
is_parent: bool, # Inject current context if IS a parent span
106-
has_parent: bool, # Extract current context if HAS a parent span
107-
):
105+
def create_new_span(self, span_name: str, attributes: list):
108106
"""Creates a new span within the tracer with the given name and
109107
attributes.
110108
@@ -121,13 +119,12 @@ def create_new_span(
121119
"""
122120
if self._tracer is None:
123121
return
124-
with self._tracer.start_as_current_span(
122+
with self._tracer.start_span(
125123
span_name, # type: ignore (Fails in Python 3.9 for invalid reason)
126-
context=self.extract_current_context() if has_parent else None,
124+
context=self.extract_current_context(),
127125
) as span:
128-
if is_parent:
129-
# Inject the current context
130-
self.inject_current_context()
126+
context = set_span_in_context(span)
127+
self.inject_current_context(context=context)
131128

132129
for attribute in attributes:
133130
span.set_attribute(attribute[0], attribute[1])

0 commit comments

Comments
 (0)