Skip to content

Commit 0dbb98e

Browse files
chore: redesign notification hooks for better ergonomics pr feedback
1 parent 59e6dba commit 0dbb98e

File tree

4 files changed

+197
-18
lines changed

4 files changed

+197
-18
lines changed

dreadnode/agent/agent.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_total_usage_from_events,
3131
)
3232
from dreadnode.agent.hooks import Hook, retry_with_feedback
33+
from dreadnode.agent.hooks.notification import NotificationBackend, TerminalNotificationBackend
3334
from dreadnode.agent.reactions import (
3435
Continue,
3536
Fail,
@@ -62,6 +63,16 @@
6263
CommitBehavior = t.Literal["always", "on-success"]
6364

6465

66+
async def _safe_send(
67+
backend: NotificationBackend, event: AgentEvent, message: str
68+
) -> None:
69+
"""Send notification with error handling."""
70+
try:
71+
await backend.send(event, message)
72+
except Exception: # noqa: BLE001
73+
logger.exception(f"Notification failed for {event.__class__.__name__}")
74+
75+
6576
class AgentWarning(UserWarning):
6677
"""Warning raised when an agent is used in a way that may not be safe or intended."""
6778

@@ -111,6 +122,24 @@ class Agent(Model):
111122
assert_scores: list[str] | t.Literal[True] = Field(default_factory=list)
112123
"""Scores to ensure are truthy, otherwise the agent task is marked as failed."""
113124

125+
notifications: t.Annotated[
126+
bool | NotificationBackend | None, SkipValidation
127+
] = Config(default=None, repr=False)
128+
"""
129+
Enable notifications.
130+
- True: Uses TerminalNotificationBackend (stderr output)
131+
- NotificationBackend instance: Uses custom backend
132+
- None/False: Disabled
133+
"""
134+
notification_events: list[type[AgentEvent]] | t.Literal["all"] = Config(
135+
default="all", repr=False
136+
)
137+
"""Which event types to notify on. Defaults to all events."""
138+
notification_formatter: t.Annotated[
139+
t.Callable[[AgentEvent], str] | None, SkipValidation
140+
] = Config(default=None, repr=False)
141+
"""Custom formatter for notification messages. If None, uses event's default representation."""
142+
114143
_generator: rg.Generator | None = PrivateAttr(None, init=False)
115144

116145
@field_validator("tools", mode="before")
@@ -129,6 +158,49 @@ def validate_tools(cls, value: t.Any) -> t.Any:
129158

130159
return tools
131160

161+
def model_post_init(self, context: t.Any) -> None:
162+
super().model_post_init(context)
163+
164+
# Auto-inject notification hook if enabled
165+
if self.notifications:
166+
backend = (
167+
self.notifications
168+
if isinstance(self.notifications, NotificationBackend)
169+
else TerminalNotificationBackend()
170+
)
171+
172+
self.hooks.append(
173+
self._create_notification_hook(
174+
backend,
175+
self.notification_events,
176+
self.notification_formatter,
177+
)
178+
)
179+
180+
def _create_notification_hook(
181+
self,
182+
backend: NotificationBackend,
183+
events: list[type[AgentEvent]] | t.Literal["all"],
184+
formatter: t.Callable[[AgentEvent], str] | None,
185+
) -> Hook:
186+
"""Create a notification hook that delegates formatting to events."""
187+
import asyncio
188+
189+
async def notification_hook(event: AgentEvent) -> None:
190+
# Filter events
191+
if events != "all" and not any(isinstance(event, et) for et in events):
192+
return
193+
194+
# Use custom formatter if provided, otherwise delegate to event
195+
message = formatter(event) if formatter else event.format_notification()
196+
197+
# Fire and forget - don't block agent execution
198+
_ = asyncio.create_task(_safe_send(backend, event, message)) # noqa: RUF006
199+
200+
return
201+
202+
return notification_hook
203+
132204
def __repr__(self) -> str:
133205
description = shorten_string(self.description or "", 50)
134206

dreadnode/agent/events.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,22 @@ def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
115115
border_style="dim",
116116
)
117117

118+
def format_notification(self) -> str:
119+
"""
120+
Format this event as a human-readable notification message.
121+
Override in subclasses for custom formatting.
122+
"""
123+
return f"{self.__class__.__name__}"
124+
118125
def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
119126
yield self.format_as_panel()
120127

121128

122129
@dataclass
123130
class AgentStart(AgentEvent):
131+
def format_notification(self) -> str:
132+
return f"Starting agent: {self.agent.name}"
133+
124134
def format_as_panel(self, *, truncate: bool = False) -> Panel:
125135
return Panel(
126136
format_message(self.messages[0], truncate=truncate),
@@ -158,6 +168,10 @@ def __repr__(self) -> str:
158168
message = f"Message(role={self.message.role}, content='{message_content}', tool_calls={tool_call_count})"
159169
return f"GenerationEnd(message={message})"
160170

171+
def format_notification(self) -> str:
172+
tokens = self.usage.total_tokens if self.usage else "unknown"
173+
return f"Generation complete ({tokens} tokens)"
174+
161175
def format_as_panel(self, *, truncate: bool = False) -> Panel:
162176
cost = round(self.estimated_cost, 6) if self.estimated_cost else ""
163177
usage = str(self.usage) or ""
@@ -173,6 +187,9 @@ def format_as_panel(self, *, truncate: bool = False) -> Panel:
173187

174188
@dataclass
175189
class AgentStalled(AgentEventInStep):
190+
def format_notification(self) -> str:
191+
return "Agent stalled: no tool calls and no stop conditions met"
192+
176193
def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
177194
return Panel(
178195
Text(
@@ -189,6 +206,9 @@ def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
189206
class AgentError(AgentEventInStep):
190207
error: BaseException
191208

209+
def format_notification(self) -> str:
210+
return f"Error: {self.error.__class__.__name__}: {self.error!s}"
211+
192212
def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
193213
return Panel(
194214
repr(self),
@@ -205,6 +225,9 @@ class ToolStart(AgentEventInStep):
205225
def __repr__(self) -> str:
206226
return f"ToolStart(tool_call={self.tool_call})"
207227

228+
def format_notification(self) -> str:
229+
return f"Starting tool: {self.tool_call.name}"
230+
208231
def format_as_panel(self, *, truncate: bool = False) -> Panel:
209232
content: RenderableType
210233
try:
@@ -245,6 +268,10 @@ def __repr__(self) -> str:
245268
message = f"Message(role={self.message.role}, content='{message_content}')"
246269
return f"ToolEnd(tool_call={self.tool_call}, message={message}, stop={self.stop})"
247270

271+
def format_notification(self) -> str:
272+
status = " (requesting stop)" if self.stop else ""
273+
return f"Finished tool: {self.tool_call.name}{status}"
274+
248275
def format_as_panel(self, *, truncate: bool = False) -> Panel:
249276
panel = format_message(self.message, truncate=truncate)
250277
subtitle = f"[dim]{self.tool_call.id}[/dim]"
@@ -294,6 +321,10 @@ class AgentEnd(AgentEvent):
294321
stop_reason: "AgentStopReason"
295322
result: "AgentResult"
296323

324+
def format_notification(self) -> str:
325+
status = "❌ Failed" if self.result.failed else "✅ Finished"
326+
return f"{status}: {self.stop_reason} (steps: {self.result.steps}, tokens: {self.result.usage.total_tokens})"
327+
297328
def format_as_panel(self, *, truncate: bool = False) -> Panel: # noqa: ARG002
298329
res = self.result
299330
status = "[bold red]Failed[/bold red]" if res.failed else "[bold green]Success[/bold green]"

dreadnode/agent/hooks/notification.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from loguru import logger
55

66
if t.TYPE_CHECKING:
7+
import httpx
8+
79
from dreadnode.agent.events import AgentEvent
8-
from dreadnode.agent.reactions import Reaction
910

1011

1112
class NotificationBackend(ABC):
@@ -27,29 +28,46 @@ async def send(self, event: "AgentEvent", message: str) -> None:
2728

2829

2930
class WebhookNotificationBackend(NotificationBackend):
30-
def __init__(self, url: str, headers: dict[str, str] | None = None):
31+
def __init__(self, url: str, headers: dict[str, str] | None = None, timeout: float = 5.0):
3132
self.url = url
3233
self.headers = headers or {}
34+
self.timeout = timeout
35+
self._client: httpx.AsyncClient | None = None
36+
37+
async def __aenter__(self) -> "WebhookNotificationBackend":
38+
import httpx
39+
40+
self._client = httpx.AsyncClient(timeout=self.timeout)
41+
return self
42+
43+
async def __aexit__(self, *args: t.Any) -> None:
44+
if self._client:
45+
await self._client.aclose()
3346

3447
async def send(self, event: "AgentEvent", message: str) -> None:
3548
import httpx
3649

37-
payload = {
50+
if not self._client:
51+
self._client = httpx.AsyncClient(timeout=self.timeout)
52+
53+
payload = self._build_payload(event, message)
54+
await self._client.post(self.url, json=payload, headers=self.headers)
55+
56+
def _build_payload(self, event: "AgentEvent", message: str) -> dict[str, str]:
57+
"""Override this to customize webhook payload."""
58+
return {
3859
"agent": event.agent.name,
3960
"event": event.__class__.__name__,
4061
"message": message,
4162
"timestamp": event.timestamp.isoformat(),
4263
}
4364

44-
async with httpx.AsyncClient() as client:
45-
await client.post(self.url, json=payload, headers=self.headers)
46-
4765

4866
def notify(
4967
event_type: "type[AgentEvent] | t.Callable[[AgentEvent], bool]",
50-
message: str | t.Callable[["AgentEvent"], str],
68+
message: str | t.Callable[["AgentEvent"], str] | None = None,
5169
backend: NotificationBackend | None = None,
52-
) -> t.Callable[["AgentEvent"], t.Awaitable["Reaction | None"]]:
70+
) -> t.Callable[["AgentEvent"], t.Awaitable[None]]:
5371
"""
5472
Create a notification hook that sends notifications when events occur.
5573
@@ -58,7 +76,8 @@ def notify(
5876
5977
Args:
6078
event_type: Event type to trigger on, or predicate function
61-
message: Static message or callable that generates message from event
79+
message: Static message or callable that generates message from event.
80+
If None, uses event.format_notification()
6281
backend: Notification backend (defaults to terminal output)
6382
6483
Returns:
@@ -73,6 +92,7 @@ def notify(
7392
agent = Agent(
7493
name="analyzer",
7594
hooks=[
95+
notify(ToolStart), # Uses default formatting
7696
notify(
7797
ToolStart,
7898
lambda e: f"Starting tool: {e.tool_name}",
@@ -83,7 +103,7 @@ def notify(
83103
"""
84104
notification_backend = backend or TerminalNotificationBackend()
85105

86-
async def notification_hook(event: "AgentEvent") -> "Reaction | None":
106+
async def notification_hook(event: "AgentEvent") -> None:
87107
should_notify = False
88108

89109
if isinstance(event_type, type):
@@ -92,15 +112,19 @@ async def notification_hook(event: "AgentEvent") -> "Reaction | None":
92112
should_notify = event_type(event)
93113

94114
if not should_notify:
95-
return None
115+
return
96116

97-
msg = message(event) if callable(message) else message
117+
# Use custom message if provided, otherwise delegate to event
118+
if message is None:
119+
msg = event.format_notification()
120+
else:
121+
msg = message(event) if callable(message) else message
98122

99123
try:
100124
await notification_backend.send(event, msg)
101125
except Exception: # noqa: BLE001
102126
logger.exception("Notification hook failed")
103127

104-
return None
128+
return
105129

106130
return notification_hook

tests/test_notification_hook.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,16 @@ async def test_terminal_notification_backend(mock_event: AgentEvent) -> None:
6464

6565

6666
async def test_webhook_notification_backend(mock_event: AgentEvent) -> None:
67+
from unittest.mock import patch
68+
69+
import httpx
70+
6771
mock_client = AsyncMock()
6872
mock_post = AsyncMock()
69-
mock_client.__aenter__.return_value.post = mock_post
73+
mock_client.post = mock_post
7074

7175
backend = WebhookNotificationBackend("https://example.com/webhook")
7276

73-
from unittest.mock import patch
74-
75-
import httpx
76-
7777
with patch.object(httpx, "AsyncClient", return_value=mock_client):
7878
await backend.send(mock_event, "Test notification")
7979

@@ -159,3 +159,55 @@ async def test_notify_hook_handles_backend_failure(mock_event: AgentEvent) -> No
159159

160160
assert reaction is None
161161
mock_logger.assert_called_once_with("Notification hook failed")
162+
163+
164+
async def test_notify_hook_uses_default_formatter(mock_event: AgentEvent) -> None:
165+
backend = MagicMock(spec=NotificationBackend)
166+
backend.send = AsyncMock()
167+
168+
hook = notify(MockEvent, backend=backend)
169+
170+
reaction = await hook(mock_event)
171+
172+
assert reaction is None
173+
backend.send.assert_called_once()
174+
# Check that it used event.format_notification()
175+
call_args = backend.send.call_args[0]
176+
assert call_args[1] == "MockEvent" # Default format_notification returns class name
177+
178+
179+
def test_agent_auto_inject_notifications_terminal() -> None:
180+
from dreadnode.agent import Agent
181+
182+
agent = Agent(name="test", notifications=True)
183+
184+
# Should have auto-injected a notification hook
185+
assert len(agent.hooks) == 1
186+
187+
188+
def test_agent_auto_inject_notifications_custom_backend() -> None:
189+
from dreadnode.agent import Agent
190+
191+
backend = LogNotificationBackend()
192+
agent = Agent(name="test", notifications=backend)
193+
194+
# Should have auto-injected a notification hook
195+
assert len(agent.hooks) == 1
196+
197+
198+
def test_agent_no_notifications_by_default() -> None:
199+
from dreadnode.agent import Agent
200+
201+
agent = Agent(name="test")
202+
203+
# Should not have any auto-injected hooks
204+
assert len(agent.hooks) == 0
205+
206+
207+
def test_agent_notifications_disabled() -> None:
208+
from dreadnode.agent import Agent
209+
210+
agent = Agent(name="test", notifications=False)
211+
212+
# Should not have any auto-injected hooks
213+
assert len(agent.hooks) == 0

0 commit comments

Comments
 (0)