Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions agentops/sdk/decorators/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
# Create a proxy class that wraps the original class
class WrappedClass(wrapped):
def __init__(self, *args, **kwargs):
# Start span when instance is created
operation_name = name or wrapped.__name__
self._agentops_span_context_manager = _create_as_current_span(operation_name, entity_kind, version)
self._agentops_active_span = self._agentops_span_context_manager.__enter__()
Expand All @@ -51,15 +50,34 @@
# Call the original __init__
super().__init__(*args, **kwargs)

def __del__(self):
# End span when instance is destroyed
async def __aenter__(self):
# Added for async context manager support
# This allows using the class with 'async with' statement

# If span is already created in __init__, just return self
if hasattr(self, "_agentops_active_span") and self._agentops_active_span is not None:
return self

# Otherwise create span (for backward compatibility)
operation_name = name or wrapped.__name__
self._agentops_span_context_manager = _create_as_current_span(operation_name, entity_kind, version)
self._agentops_active_span = self._agentops_span_context_manager.__enter__()
return self

Check warning on line 65 in agentops/sdk/decorators/factory.py

View check run for this annotation

Codecov / codecov/patch

agentops/sdk/decorators/factory.py#L62-L65

Added lines #L62 - L65 were not covered by tests

async def __aexit__(self, exc_type, exc_val, exc_tb):
# Added for proper async cleanup
# This ensures spans are properly closed when using 'async with'

if hasattr(self, "_agentops_active_span") and hasattr(self, "_agentops_span_context_manager"):
try:
_record_entity_output(self._agentops_active_span, self)
except Exception as e:
logger.warning(f"Failed to record entity output: {e}")

self._agentops_span_context_manager.__exit__(None, None, None)
self._agentops_span_context_manager.__exit__(exc_type, exc_val, exc_tb)
# Clear the span references after cleanup
self._agentops_span_context_manager = None
self._agentops_active_span = None

# Preserve metadata of the original class
WrappedClass.__name__ = wrapped.__name__
Expand Down Expand Up @@ -136,11 +154,13 @@
with _create_as_current_span(operation_name, entity_kind, version) as span:
try:
_record_entity_input(span, args, kwargs)

except Exception as e:
logger.warning(f"Failed to record entity input: {e}")

try:
result = wrapped(*args, **kwargs)

try:
_record_entity_output(span, result)
except Exception as e:
Expand Down
26 changes: 25 additions & 1 deletion tests/unit/sdk/test_decorators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import AsyncGenerator
import asyncio

import pytest

from agentops.sdk.decorators import agent, operation, session, workflow, task
from agentops.semconv import SpanKind
from agentops.semconv.span_attributes import SpanAttributes
from tests.unit.sdk.instrumentation_tester import InstrumentationTester
from agentops.sdk.decorators.factory import create_entity_decorator


class TestSpanNesting:
Expand Down Expand Up @@ -600,3 +601,26 @@ def test_workflow_session():
assert transform_task.parent is not None
assert workflow_span.context is not None
assert transform_task.parent.span_id == workflow_span.context.span_id


@pytest.mark.asyncio
async def test_async_context_manager():
"""
Tests async context manager functionality (__aenter__, __aexit__).
"""

# Create a simple decorated class
@create_entity_decorator("test")
class TestClass:
def __init__(self):
self.value = 42

# Cover __aenter__ and __aexit__ (normal exit)
async with TestClass() as instance:
assert hasattr(instance, "_agentops_active_span")
assert instance._agentops_active_span is not None

# Cover __aenter__ and __aexit__ (exceptional exit)
with pytest.raises(ValueError):
async with TestClass() as instance:
raise ValueError("Trigger exception for __aexit__ coverage")
Loading