Skip to content

Commit 889852f

Browse files
authored
Fix for new OpenAI Agents SDK (#1152)
1 parent d419558 commit 889852f

File tree

4 files changed

+185
-170
lines changed

4 files changed

+185
-170
lines changed

logfire/_internal/integrations/openai_agents.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from logfire._internal.utils import handle_internal_errors, log_internal_error, truncate_string
4141

4242
if TYPE_CHECKING: # pragma: no cover
43-
from agents.tracing.setup import TraceProvider
43+
from agents.tracing import TraceProvider
4444
from openai.types.responses import Response
4545

4646
from logfire import Logfire, LogfireSpan
@@ -130,18 +130,28 @@ def __getattr__(self, item: Any) -> Any:
130130

131131
@classmethod
132132
def install(cls, logfire_instance: Logfire) -> None:
133-
name = 'GLOBAL_TRACE_PROVIDER'
134-
original = getattr(agents.tracing, name)
135-
if isinstance(original, cls):
136-
return
137-
wrapper = cls(original, logfire_instance)
138-
for module_name, mod in sys.modules.items():
139-
if module_name.startswith('agents'):
140-
try:
141-
if getattr(mod, name, None) is original:
142-
setattr(mod, name, wrapper)
143-
except Exception: # pragma: no cover
144-
pass
133+
try:
134+
from agents.tracing import get_trace_provider, set_trace_provider
135+
except ImportError: # pragma: no cover
136+
# Handle older versions of agents where these functions are not available
137+
name = 'GLOBAL_TRACE_PROVIDER'
138+
original = getattr(agents.tracing, name)
139+
if isinstance(original, cls):
140+
return
141+
wrapper = cls(original, logfire_instance)
142+
for module_name, mod in sys.modules.items():
143+
if module_name.startswith('agents'):
144+
try:
145+
if getattr(mod, name, None) is original:
146+
setattr(mod, name, wrapper)
147+
except Exception: # pragma: no cover
148+
pass
149+
else:
150+
original = get_trace_provider()
151+
if isinstance(original, cls):
152+
return
153+
wrapper = cls(original, logfire_instance)
154+
set_trace_provider(wrapper) # type: ignore
145155

146156

147157
@dataclass
@@ -185,13 +195,13 @@ class LogfireWrapperBase(Generic[T]):
185195
span_helper: LogfireSpanHelper
186196
token: contextvars.Token[T | None] | None = None
187197

188-
def start(self, mark_as_current: bool = False):
198+
def start(self, mark_as_current: bool = False) -> None:
189199
self.span_helper.start(mark_as_current)
190200
if mark_as_current:
191201
self.attach()
192202
return self.wrapped.start()
193203

194-
def finish(self, reset_current: bool = False):
204+
def finish(self, reset_current: bool = False) -> None:
195205
self.on_ending()
196206
self.span_helper.end(reset_current)
197207
if reset_current:
@@ -204,7 +214,7 @@ def __enter__(self) -> Self:
204214
self.attach()
205215
return self
206216

207-
def __exit__(self, exc_type: type[BaseException], exc_val: BaseException, exc_tb: TracebackType):
217+
def __exit__(self, exc_type: type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None:
208218
self.on_ending()
209219
self.span_helper.__exit__(exc_type, exc_val, exc_tb)
210220
self.wrapped.finish()

tests/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import anyio._backends._asyncio # noqa # type: ignore
99
import pytest
10-
from agents.tracing.setup import GLOBAL_TRACE_PROVIDER
10+
from agents.tracing import get_trace_provider
1111
from opentelemetry import trace
1212
from opentelemetry.sdk._logs.export import SimpleLogRecordProcessor
1313
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
@@ -28,8 +28,8 @@
2828
os.environ['LOGFIRE_TOKEN'] = ''
2929

3030

31-
GLOBAL_TRACE_PROVIDER.shutdown()
32-
GLOBAL_TRACE_PROVIDER.set_processors([])
31+
get_trace_provider().shutdown()
32+
get_trace_provider().set_processors([])
3333

3434

3535
@pytest.fixture(scope='session', autouse=True)

tests/otel_integrations/test_openai_agents.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
input_guardrail,
3131
trace,
3232
)
33+
from agents.tracing import get_trace_provider
3334
from agents.tracing.span_data import MCPListToolsSpanData, ResponseSpanData
3435
from agents.tracing.spans import NoOpSpan
3536
from agents.tracing.traces import NoOpTrace
@@ -1073,8 +1074,6 @@ def all_subclasses(cls: type) -> list[type]:
10731074
def test_unknown_span(exporter: TestExporter):
10741075
logfire.instrument_openai_agents()
10751076

1076-
from agents.tracing.setup import GLOBAL_TRACE_PROVIDER
1077-
10781077
class MySpanData(SpanData):
10791078
def export(self):
10801079
return {'foo': 'bar', 'type': self.type}
@@ -1085,7 +1084,7 @@ def type(self) -> str:
10851084

10861085
with trace('my_trace', trace_id='trace_123', group_id='456') as t:
10871086
assert t.name == 'my_trace'
1088-
with GLOBAL_TRACE_PROVIDER.create_span(span_data=MySpanData(), span_id='span_789') as s:
1087+
with get_trace_provider().create_span(span_data=MySpanData(), span_id='span_789') as s:
10891088
assert s.trace_id == 'trace_123'
10901089
assert s.span_id == 'span_789'
10911090
assert s.parent_id is None

0 commit comments

Comments
 (0)