Skip to content

Commit 7e9addb

Browse files
mindpowerkthota-g
andauthored
test: add 14 unit tests for utils/helper.py, utils/telemetry.py (#33)
* Add unit tests for utils/helper.py,utils/telemetry.py * Update test_helpers.py * Update test_helpers.py --------- Co-authored-by: kthota-g <[email protected]>
1 parent c027843 commit 7e9addb

File tree

2 files changed

+308
-0
lines changed

2 files changed

+308
-0
lines changed

tests/utils/test_helpers.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from typing import Any
2+
import pytest
3+
from unittest.mock import MagicMock
4+
from uuid import uuid4
5+
from a2a.utils.helpers import (
6+
create_task_obj,
7+
append_artifact_to_task,
8+
build_text_artifact,
9+
validate,
10+
)
11+
from a2a.types import (
12+
Artifact,
13+
MessageSendParams,
14+
Message,
15+
Task,
16+
TaskArtifactUpdateEvent,
17+
TaskState,
18+
TaskStatus,
19+
TextPart,
20+
Part,
21+
)
22+
from a2a.utils.errors import ServerError, UnsupportedOperationError
23+
24+
# --- Helper Data ---
25+
TEXT_PART_DATA: dict[str, Any] = {'type': 'text', 'text': 'Hello'}
26+
27+
MINIMAL_MESSAGE_USER: dict[str, Any] = {
28+
'role': 'user',
29+
'parts': [TEXT_PART_DATA],
30+
'messageId': 'msg-123',
31+
'type': 'message',
32+
}
33+
34+
MINIMAL_TASK_STATUS: dict[str, Any] = {'state': 'submitted'}
35+
36+
MINIMAL_TASK: dict[str, Any] = {
37+
'id': 'task-abc',
38+
'contextId': 'session-xyz',
39+
'status': MINIMAL_TASK_STATUS,
40+
'type': 'task',
41+
}
42+
43+
# Test create_task_obj
44+
def test_create_task_obj():
45+
message = Message(**MINIMAL_MESSAGE_USER)
46+
send_params = MessageSendParams(message=message)
47+
48+
task = create_task_obj(send_params)
49+
assert task.id is not None
50+
assert task.contextId == message.contextId
51+
assert task.status.state == TaskState.submitted
52+
assert len(task.history) == 1
53+
assert task.history[0] == message
54+
55+
56+
# Test append_artifact_to_task
57+
def test_append_artifact_to_task():
58+
# Prepare base task
59+
task = Task(**MINIMAL_TASK)
60+
assert task.id == 'task-abc'
61+
assert task.contextId == 'session-xyz'
62+
assert task.status.state == TaskState.submitted
63+
assert task.history is None
64+
assert task.artifacts is None
65+
assert task.metadata is None
66+
67+
# Prepare appending artifact and event
68+
artifact_1 = Artifact(
69+
artifactId="artifact-123", parts=[Part(root=TextPart(text="Hello"))]
70+
)
71+
append_event_1 = TaskArtifactUpdateEvent(artifact=artifact_1, append=False, taskId="123", contextId="123")
72+
73+
# Test adding a new artifact (not appending)
74+
append_artifact_to_task(task, append_event_1)
75+
assert len(task.artifacts) == 1
76+
assert task.artifacts[0].artifactId == "artifact-123"
77+
assert task.artifacts[0].name == None
78+
assert len(task.artifacts[0].parts) == 1
79+
assert task.artifacts[0].parts[0].root.text == "Hello"
80+
81+
# Test replacing the artifact
82+
artifact_2 = Artifact(
83+
artifactId="artifact-123", name="updated name", parts=[Part(root=TextPart(text="Updated"))]
84+
)
85+
append_event_2 = TaskArtifactUpdateEvent(artifact=artifact_2, append=False, taskId="123", contextId="123")
86+
append_artifact_to_task(task, append_event_2)
87+
assert len(task.artifacts) == 1 # Should still have one artifact
88+
assert task.artifacts[0].artifactId == "artifact-123"
89+
assert task.artifacts[0].name == "updated name"
90+
assert len(task.artifacts[0].parts) == 1
91+
assert task.artifacts[0].parts[0].root.text == "Updated"
92+
93+
# Test appending parts to an existing artifact
94+
artifact_with_parts = Artifact(
95+
artifactId="artifact-123", parts=[Part(root=TextPart(text="Part 2"))]
96+
)
97+
append_event_3 = TaskArtifactUpdateEvent(artifact=artifact_with_parts, append=True, taskId="123", contextId="123")
98+
append_artifact_to_task(task, append_event_3)
99+
assert len(task.artifacts[0].parts) == 2
100+
assert task.artifacts[0].parts[0].root.text == "Updated"
101+
assert task.artifacts[0].parts[1].root.text == "Part 2"
102+
103+
# Test adding another new artifact
104+
another_artifact_with_parts = Artifact(
105+
artifactId="new_artifact", parts=[Part(root=TextPart(text="new artifact Part 1"))]
106+
)
107+
append_event_4 = TaskArtifactUpdateEvent(artifact=another_artifact_with_parts, append=False, taskId="123", contextId="123")
108+
append_artifact_to_task(task, append_event_4)
109+
assert len(task.artifacts) == 2
110+
assert task.artifacts[0].artifactId == "artifact-123"
111+
assert task.artifacts[1].artifactId == "new_artifact"
112+
assert len(task.artifacts[0].parts) == 2
113+
assert len(task.artifacts[1].parts) == 1
114+
115+
# Test appending part to a task that does not have a matching artifact
116+
non_existing_artifact_with_parts = Artifact(
117+
artifactId="artifact-456", parts=[Part(root=TextPart(text="Part 1"))]
118+
)
119+
append_event_5 = TaskArtifactUpdateEvent(artifact=non_existing_artifact_with_parts, append=True, taskId="123", contextId="123")
120+
append_artifact_to_task(task, append_event_5)
121+
assert len(task.artifacts) == 2
122+
assert len(task.artifacts[0].parts) == 2
123+
assert len(task.artifacts[1].parts) == 1
124+
125+
# Test build_text_artifact
126+
def test_build_text_artifact():
127+
artifact_id = "text_artifact"
128+
text = "This is a sample text"
129+
artifact = build_text_artifact(text, artifact_id)
130+
131+
assert artifact.artifactId == artifact_id
132+
assert len(artifact.parts) == 1
133+
assert artifact.parts[0].root.text == text
134+
135+
136+
# Test validate decorator
137+
def test_validate_decorator():
138+
class TestClass:
139+
condition = True
140+
141+
@validate(lambda self: self.condition, "Condition not met")
142+
def test_method(self):
143+
return "Success"
144+
145+
obj = TestClass()
146+
147+
# Test passing condition
148+
assert obj.test_method() == "Success"
149+
150+
# Test failing condition
151+
obj.condition = False
152+
with pytest.raises(ServerError) as exc_info:
153+
obj.test_method()
154+
assert "Condition not met" in str(exc_info.value)

tests/utils/test_telemetry.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import pytest
2+
import types
3+
import asyncio
4+
from unittest import mock
5+
from a2a.utils.telemetry import trace_function, trace_class
6+
7+
@pytest.fixture
8+
def mock_span():
9+
span = mock.MagicMock()
10+
return span
11+
12+
@pytest.fixture
13+
def mock_tracer(mock_span):
14+
tracer = mock.MagicMock()
15+
tracer.start_as_current_span.return_value.__enter__.return_value = mock_span
16+
tracer.start_as_current_span.return_value.__exit__.return_value = False
17+
return tracer
18+
19+
@pytest.fixture(autouse=True)
20+
def patch_trace_get_tracer(mock_tracer):
21+
with mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer):
22+
yield
23+
24+
def test_trace_function_sync_success(mock_span):
25+
@trace_function
26+
def foo(x, y):
27+
return x + y
28+
29+
result = foo(2, 3)
30+
assert result == 5
31+
mock_span.set_status.assert_called()
32+
mock_span.set_status.assert_any_call(mock.ANY)
33+
mock_span.record_exception.assert_not_called()
34+
35+
def test_trace_function_sync_exception(mock_span):
36+
@trace_function
37+
def bar():
38+
raise ValueError("fail")
39+
40+
with pytest.raises(ValueError):
41+
bar()
42+
mock_span.record_exception.assert_called()
43+
mock_span.set_status.assert_any_call(mock.ANY, description="fail")
44+
45+
def test_trace_function_sync_attribute_extractor_called(mock_span):
46+
called = {}
47+
def attr_extractor(span, args, kwargs, result, exception):
48+
called['called'] = True
49+
assert span is mock_span
50+
assert exception is None
51+
assert result == 42
52+
53+
@trace_function(attribute_extractor=attr_extractor)
54+
def foo():
55+
return 42
56+
57+
foo()
58+
assert called['called']
59+
60+
def test_trace_function_sync_attribute_extractor_error_logged(mock_span):
61+
with mock.patch("a2a.utils.telemetry.logger") as logger:
62+
def attr_extractor(span, args, kwargs, result, exception):
63+
raise RuntimeError("attr fail")
64+
65+
@trace_function(attribute_extractor=attr_extractor)
66+
def foo():
67+
return 1
68+
69+
foo()
70+
logger.error.assert_any_call(mock.ANY)
71+
72+
@pytest.mark.asyncio
73+
async def test_trace_function_async_success(mock_span):
74+
@trace_function
75+
async def foo(x):
76+
await asyncio.sleep(0)
77+
return x * 2
78+
79+
result = await foo(4)
80+
assert result == 8
81+
mock_span.set_status.assert_called()
82+
mock_span.record_exception.assert_not_called()
83+
84+
@pytest.mark.asyncio
85+
async def test_trace_function_async_exception(mock_span):
86+
@trace_function
87+
async def bar():
88+
await asyncio.sleep(0)
89+
raise RuntimeError("async fail")
90+
91+
with pytest.raises(RuntimeError):
92+
await bar()
93+
mock_span.record_exception.assert_called()
94+
mock_span.set_status.assert_any_call(mock.ANY, description="async fail")
95+
96+
@pytest.mark.asyncio
97+
async def test_trace_function_async_attribute_extractor_called(mock_span):
98+
called = {}
99+
def attr_extractor(span, args, kwargs, result, exception):
100+
called['called'] = True
101+
assert exception is None
102+
assert result == 99
103+
104+
@trace_function(attribute_extractor=attr_extractor)
105+
async def foo():
106+
return 99
107+
108+
await foo()
109+
assert called['called']
110+
111+
def test_trace_function_with_args_and_attributes(mock_span):
112+
@trace_function(span_name="custom.span", attributes={"foo": "bar"})
113+
def foo():
114+
return 1
115+
116+
foo()
117+
mock_span.set_attribute.assert_any_call("foo", "bar")
118+
119+
def test_trace_class_exclude_list(mock_span):
120+
@trace_class(exclude_list=["skip_me"])
121+
class MyClass:
122+
def a(self): return "a"
123+
def skip_me(self): return "skip"
124+
def __str__(self): return "str"
125+
126+
obj = MyClass()
127+
assert obj.a() == "a"
128+
assert obj.skip_me() == "skip"
129+
# Only 'a' is traced, not 'skip_me' or dunder
130+
assert hasattr(obj.a, "__wrapped__")
131+
assert not hasattr(obj.skip_me, "__wrapped__")
132+
133+
def test_trace_class_include_list(mock_span):
134+
@trace_class(include_list=["only_this"])
135+
class MyClass:
136+
def only_this(self): return "yes"
137+
def not_this(self): return "no"
138+
139+
obj = MyClass()
140+
assert obj.only_this() == "yes"
141+
assert obj.not_this() == "no"
142+
assert hasattr(obj.only_this, "__wrapped__")
143+
assert not hasattr(obj.not_this, "__wrapped__")
144+
145+
def test_trace_class_dunder_not_traced(mock_span):
146+
@trace_class()
147+
class MyClass:
148+
def __init__(self): self.x = 1
149+
def foo(self): return "foo"
150+
151+
obj = MyClass()
152+
assert obj.foo() == "foo"
153+
assert hasattr(obj.foo, "__wrapped__")
154+
assert hasattr(obj, "x")

0 commit comments

Comments
 (0)