diff --git a/README.md b/README.md
index 9f4b308..565bb31 100644
--- a/README.md
+++ b/README.md
@@ -470,6 +470,49 @@ await bus.dispatch(DataEvent())
+### ๐ง Middleware System
+
+Add cross-cutting concerns like analytics, error handling, and logging using Django-style middleware:
+
+```python
+from bubus import EventBus, BaseEvent
+from bubus.middleware import EventBusMiddleware
+
+class LoggingMiddleware(EventBusMiddleware):
+ def __call__(self, get_handler_result):
+ async def get_handler_result_wrapped_by_middleware(event: BaseEvent):
+ print(f"Processing {event.event_type}")
+
+ try:
+ result = await get_handler_result(event)
+ print(f"Handler succeeded")
+ return result
+ except Exception as e:
+ print(f"Handler failed: {e}")
+ raise
+
+ return get_handler_result_wrapped_by_middleware
+
+# Create event bus with middleware
+bus = EventBus(middlewares=[LoggingMiddleware()])
+```
+
+**Built-in Middleware:**
+
+```python
+from bubus.middleware import WALEventBusMiddleware
+
+# WAL middleware for event persistence
+bus = EventBus(middlewares=[
+ WALEventBusMiddleware('./events.jsonl')
+])
+
+# Or enable WAL automatically with wal_path parameter
+bus = EventBus(wal_path='./events.jsonl') # Automatically adds WAL middleware
+```
+
+
+
### ๐ Write-Ahead Logging
Persist events automatically to a `jsonl` file for future replay and debugging:
diff --git a/bubus/__init__.py b/bubus/__init__.py
index df6e6e2..baf22d9 100644
--- a/bubus/__init__.py
+++ b/bubus/__init__.py
@@ -2,12 +2,17 @@
from bubus.models import BaseEvent, EventHandler, EventResult, PythonIdentifierStr, PythonIdStr, UUIDStr
from bubus.service import EventBus
+from bubus.middleware import EventBusMiddleware, HandlerStartedAnalyticsEvent, HandlerCompletedAnalyticsEvent, WALEventBusMiddleware
__all__ = [
'EventBus',
'BaseEvent',
'EventResult',
'EventHandler',
+ 'EventBusMiddleware',
+ 'HandlerStartedAnalyticsEvent',
+ 'HandlerCompletedAnalyticsEvent',
+ 'WALEventBusMiddleware',
'UUIDStr',
'PythonIdStr',
'PythonIdentifierStr',
diff --git a/bubus/middleware.py b/bubus/middleware.py
new file mode 100644
index 0000000..1c45368
--- /dev/null
+++ b/bubus/middleware.py
@@ -0,0 +1,175 @@
+"""Middleware system for event bus with Django-style nested function pattern."""
+
+import asyncio
+import traceback
+from collections.abc import Awaitable, Callable
+from datetime import UTC, datetime
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from bubus.models import BaseEvent, EventHandler, PythonIdStr, get_handler_id, get_handler_name
+
+if TYPE_CHECKING:
+ from bubus.service import EventBus
+
+
+# Type alias for middleware functions
+EventMiddleware = Callable[['EventBus', EventHandler, 'BaseEvent[Any]', Callable[[], Awaitable[Any]]], Awaitable[Any]]
+
+
+class HandlerStartedAnalyticsEvent(BaseEvent[None]):
+ """Analytics event dispatched when a handler starts execution"""
+
+ event_id: str # ID of the event being processed
+ started_at: datetime
+ event_bus_id: str
+ event_bus_name: str
+ handler_id: str
+ handler_name: str
+ handler_class: str
+
+
+class HandlerCompletedAnalyticsEvent(BaseEvent[None]):
+ """Analytics event dispatched when a handler completes execution"""
+
+ event_id: str # ID of the event being processed
+ completed_at: datetime
+ error: Exception | None = None
+ traceback_info: str = ''
+ event_bus_id: str
+ event_bus_name: str
+ handler_id: str
+ handler_name: str
+ handler_class: str
+
+
+class EventBusMiddleware:
+ """Base class for Django-style EventBus middleware"""
+
+ def __call__(self, get_handler_result: Callable[['BaseEvent[Any]'], Awaitable[Any]]) -> Callable[['BaseEvent[Any]'], Awaitable[Any]]:
+ """
+ Django-style middleware pattern.
+
+ Args:
+ get_handler_result: The next middleware in the chain or the actual handler
+
+ Returns:
+ Wrapped function that processes events
+ """
+ async def get_handler_result_wrapped_by_middleware(event: BaseEvent[Any]) -> Any:
+ return await get_handler_result(event)
+
+ return get_handler_result_wrapped_by_middleware
+
+
+class WALEventBusMiddleware(EventBusMiddleware):
+ """Write-Ahead Logging middleware for persisting events to JSONL files"""
+
+ def __init__(self, wal_path: Path | str):
+ self.wal_path = Path(wal_path)
+
+ def __call__(self, get_handler_result: Callable[['BaseEvent[Any]'], Awaitable[Any]]) -> Callable[['BaseEvent[Any]'], Awaitable[Any]]:
+ async def get_handler_result_wrapped_by_middleware(event: BaseEvent[Any]) -> Any:
+ # Just execute the handler and log completed events to WAL
+ # This is a simplified implementation - the original EventBus did more complex WAL handling
+ try:
+ result = await get_handler_result(event)
+
+ # Log completed event to WAL
+ try:
+ self.wal_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Use async I/O if available, otherwise sync
+ try:
+ import anyio
+ async with await anyio.open_file(self.wal_path, 'a', encoding='utf-8') as f:
+ await f.write(event.model_dump_json() + '\n')
+ except ImportError:
+ # Fallback to sync I/O
+ with open(self.wal_path, 'a', encoding='utf-8') as f:
+ f.write(event.model_dump_json() + '\n')
+ except Exception:
+ # Don't let WAL errors break the handler
+ pass
+
+ return result
+ except Exception:
+ # Could log error events here too, but keeping it simple
+ raise
+
+ return get_handler_result_wrapped_by_middleware
+
+
+class AnalyticsEventBusMiddleware(EventBusMiddleware):
+ """Analytics middleware that dispatches analytics events for handler execution"""
+
+ def __init__(self, analytics_bus: 'EventBus'):
+ self.analytics_bus = analytics_bus
+
+ def __call__(self, get_handler_result: Callable[['BaseEvent[Any]'], Awaitable[Any]]) -> Callable[['BaseEvent[Any]'], Awaitable[Any]]:
+ async def get_handler_result_wrapped_by_middleware(event: BaseEvent[Any]) -> Any:
+ # Access event bus and handler info from the event context
+ from bubus.models import get_handler_id, get_handler_name
+ from bubus.service import _current_handler_id_context, inside_handler_context
+
+ # We can access the event bus through event.event_bus
+ event_bus = event.event_bus
+
+ # Get handler information from context
+ handler_id = _current_handler_id_context.get()
+
+ # Get the event result object which contains handler information
+ event_result = None
+ if handler_id and handler_id in event.event_results:
+ event_result = event.event_results[handler_id]
+
+ # Dispatch started analytics event if we have the context
+ if event_result and inside_handler_context.get():
+ started_event = HandlerStartedAnalyticsEvent(
+ event_id=event.event_id,
+ started_at=event_result.started_at or datetime.now(UTC),
+ event_bus_id=event_bus.id,
+ event_bus_name=event_bus.name,
+ handler_id=handler_id,
+ handler_name=event_result.handler_name,
+ handler_class=event_result.handler_class,
+ )
+ self.analytics_bus.dispatch(started_event)
+
+ try:
+ result = await get_handler_result(event)
+
+ # Dispatch completed analytics event
+ if event_result and inside_handler_context.get():
+ completed_event = HandlerCompletedAnalyticsEvent(
+ event_id=event.event_id,
+ completed_at=datetime.now(UTC),
+ error=None,
+ traceback_info='',
+ event_bus_id=event_bus.id,
+ event_bus_name=event_bus.name,
+ handler_id=handler_id,
+ handler_name=event_result.handler_name,
+ handler_class=event_result.handler_class,
+ )
+ self.analytics_bus.dispatch(completed_event)
+
+ return result
+ except Exception as e:
+ # Dispatch completed analytics event with error
+ if event_result and inside_handler_context.get():
+ completed_event = HandlerCompletedAnalyticsEvent(
+ event_id=event.event_id,
+ completed_at=datetime.now(UTC),
+ error=e,
+ traceback_info=traceback.format_exc(),
+ event_bus_id=event_bus.id,
+ event_bus_name=event_bus.name,
+ handler_id=handler_id,
+ handler_name=event_result.handler_name,
+ handler_class=event_result.handler_class,
+ )
+ self.analytics_bus.dispatch(completed_event)
+ raise
+
+ return get_handler_result_wrapped_by_middleware
\ No newline at end of file
diff --git a/bubus/models.py b/bubus/models.py
index b165bc0..2fa7f6d 100644
--- a/bubus/models.py
+++ b/bubus/models.py
@@ -2,6 +2,7 @@
import inspect
import logging
import os
+import traceback
from collections.abc import Awaitable, Callable, Generator
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, Protocol, Self, TypeAlias, cast, runtime_checkable
@@ -923,6 +924,9 @@ def log_tree(
log_eventresult_tree(self, indent, is_last, child_events_by_parent)
+# Analytics events are now in bubus.middleware module
+
+
# Resolve forward references
BaseEvent.model_rebuild()
EventResult.model_rebuild()
diff --git a/bubus/service.py b/bubus/service.py
index 4ca5b7a..5afcefc 100644
--- a/bubus/service.py
+++ b/bubus/service.py
@@ -2,10 +2,11 @@
import contextvars
import inspect
import logging
+import traceback
import warnings
import weakref
from collections import defaultdict, deque
-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Literal, TypeVar, cast, overload
@@ -51,6 +52,8 @@ class QueueShutDown(Exception):
EventPatternType = PythonIdentifierStr | Literal['*'] | type['BaseEvent[Any]']
+# EventBusMiddleware will be imported dynamically to avoid circular imports
+
class CleanShutdownQueue(asyncio.Queue[QueueEntryType]):
"""asyncio.Queue subclass that handles shutdown cleanly without warnings."""
@@ -264,6 +267,7 @@ def __init__(
wal_path: Path | str | None = None,
parallel_handlers: bool = False,
max_history_size: int | None = 50, # Keep only 50 events in history
+ middlewares: list[Any] | None = None,
):
self.id = uuid7str()
self.name = name or f'{self.__class__.__name__}_{self.id[-8:]}'
@@ -317,6 +321,13 @@ def __init__(
self.parallel_handlers = parallel_handlers
self.wal_path = Path(wal_path) if wal_path else None
self._on_idle = None
+
+ # Set up middlewares, adding WAL middleware if wal_path is provided
+ self.middlewares = middlewares or []
+ if wal_path:
+ # Import here to avoid circular imports
+ from bubus.middleware import WALEventBusMiddleware
+ self.middlewares.append(WALEventBusMiddleware(wal_path))
# Memory leak prevention settings
self.max_history_size = max_history_size
@@ -949,7 +960,6 @@ async def process_event(self, event: 'BaseEvent[Any]', timeout: float | None = N
await self._execute_handlers(event, handlers=applicable_handlers, timeout=timeout)
await self._default_log_handler(event)
- await self._default_wal_handler(event)
# Mark event as complete if all handlers are done
event.event_mark_complete_if_all_handlers_completed()
@@ -1023,8 +1033,8 @@ async def _execute_handlers(
context = contextvars.copy_context()
for handler_id, handler in applicable_handlers.items():
task = asyncio.create_task(
- self._execute_sync_or_async_handler(event, handler, timeout=timeout),
- name=f'{self}._execute_sync_or_async_handler({event}, {get_handler_name(handler)})',
+ self._execute_handler_with_middlewares(event, handler, timeout=timeout),
+ name=f'{self}._execute_handler_with_middlewares({event}, {get_handler_name(handler)})',
context=context,
)
handler_tasks[handler_id] = (task, handler)
@@ -1034,20 +1044,50 @@ async def _execute_handlers(
try:
await task
except Exception:
- # Error already logged and recorded in _execute_sync_or_async_handler
+ # Error already logged and recorded in _execute_handler_with_middlewares
pass
else:
# otherwise, execute handlers serially, wait until each one completes before moving on to the next
for handler_id, handler in applicable_handlers.items():
try:
- await self._execute_sync_or_async_handler(event, handler, timeout=timeout)
+ await self._execute_handler_with_middlewares(event, handler, timeout=timeout)
except Exception as e:
- # Error already logged and recorded in _execute_sync_or_async_handler
+ # Error already logged and recorded in _execute_handler_with_middlewares
logger.debug(
f'โ {self} Handler {get_handler_name(handler)}#{str(id(handler))[-4:]}({event}) failed with {type(e).__name__}: {e}'
)
pass
+ async def _execute_handler_with_middlewares(
+ self, event: 'BaseEvent[T_EventResultType]', handler: EventHandler, timeout: float | None = None
+ ) -> Any:
+ """Execute a handler through the Django-style middleware chain"""
+ if not self.middlewares:
+ # No middlewares, execute handler directly
+ return await self._execute_sync_or_async_handler(event, handler, timeout)
+
+ # Create Django-style middleware chain by wrapping the handler in middleware layers
+ async def base_handler(event: 'BaseEvent[Any]') -> Any:
+ return await self._execute_sync_or_async_handler(event, handler, timeout)
+
+ # Wrap the handler with each middleware (in reverse order for correct execution)
+ wrapped_handler = base_handler
+ for middleware in reversed(self.middlewares):
+ try:
+ wrapped_handler = middleware(wrapped_handler)
+ except Exception as e:
+ # Log middleware initialization error and re-raise
+ handler_id = get_handler_id(handler, self)
+ logger.exception(
+ f'โ {self} Error initializing middleware {middleware.__class__.__name__} '
+ f'for handler {get_handler_name(handler)}#{handler_id[-4:]}({event}) -> {type(e).__name__}({e})',
+ exc_info=True,
+ )
+ raise
+
+ # Execute the wrapped handler
+ return await wrapped_handler(event)
+
async def _execute_sync_or_async_handler(
self, event: 'BaseEvent[T_EventResultType]', handler: EventHandler, timeout: float | None = None
) -> Any:
@@ -1256,19 +1296,7 @@ async def _default_log_handler(self, event: 'BaseEvent[Any]') -> None:
# )
pass
- async def _default_wal_handler(self, event: 'BaseEvent[Any]') -> None:
- """Persist completed event to WAL file as JSONL"""
-
- if not self.wal_path:
- return None
-
- try:
- event_json = event.model_dump_json() # pyright: ignore[reportUnknownMemberType]
- self.wal_path.parent.mkdir(parents=True, exist_ok=True)
- async with await anyio.open_file(self.wal_path, 'a', encoding='utf-8') as f: # pyright: ignore[reportUnknownMemberType]
- await f.write(event_json + '\n') # pyright: ignore[reportUnknownMemberType]
- except Exception as e:
- logger.error(f'โ {self} Failed to save event {event.event_id} to WAL file: {type(e).__name__} {e}\n{event}')
+ # WAL functionality is now handled by WALEventBusMiddleware
def cleanup_excess_events(self) -> int:
"""
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
new file mode 100644
index 0000000..4e21a93
--- /dev/null
+++ b/tests/test_middleware.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+"""
+Test script to verify the Django-style middleware functionality.
+"""
+import asyncio
+import traceback
+from datetime import UTC, datetime
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from bubus import BaseEvent, EventBus
+from bubus.middleware import (
+ EventBusMiddleware,
+ HandlerStartedAnalyticsEvent,
+ HandlerCompletedAnalyticsEvent,
+ WALEventBusMiddleware,
+)
+from bubus.models import get_handler_id, get_handler_name
+
+
+class TestEvent(BaseEvent[str]):
+ message: str
+
+
+class AnalyticsMiddleware(EventBusMiddleware):
+ """Middleware that dispatches analytics events"""
+
+ def __init__(self, analytics_bus: EventBus):
+ self.analytics_bus = analytics_bus
+ super().__init__()
+
+ def __call__(self, get_handler_result):
+ async def get_handler_result_wrapped_by_middleware(event: BaseEvent):
+ # Note: In the Django pattern, we don't have direct access to handler/eventbus
+ # This is a simplified version for testing
+
+ # Simulate analytics event before handler
+ await self.analytics_bus.dispatch(HandlerStartedAnalyticsEvent(
+ event_id=event.event_id,
+ started_at=datetime.now(UTC),
+ event_bus_id="test_bus_id",
+ event_bus_name="TestBus",
+ handler_id="test_handler_id",
+ handler_name="test_handler",
+ handler_class="test_module.TestHandler",
+ ))
+
+ try:
+ result = await get_handler_result(event)
+
+ # Simulate analytics event after successful handler
+ await self.analytics_bus.dispatch(HandlerCompletedAnalyticsEvent(
+ event_id=event.event_id,
+ completed_at=datetime.now(UTC),
+ error=None,
+ traceback_info="",
+ event_bus_id="test_bus_id",
+ event_bus_name="TestBus",
+ handler_id="test_handler_id",
+ handler_name="test_handler",
+ handler_class="test_module.TestHandler",
+ ))
+
+ return result
+ except Exception as e:
+ # Simulate analytics event after failed handler
+ await self.analytics_bus.dispatch(HandlerCompletedAnalyticsEvent(
+ event_id=event.event_id,
+ completed_at=datetime.now(UTC),
+ error=e,
+ traceback_info=traceback.format_exc(),
+ event_bus_id="test_bus_id",
+ event_bus_name="TestBus",
+ handler_id="test_handler_id",
+ handler_name="test_handler",
+ handler_class="test_module.TestHandler",
+ ))
+ raise
+
+ return get_handler_result_wrapped_by_middleware
+
+
+class LoggingMiddleware(EventBusMiddleware):
+ """Simple logging middleware for testing"""
+
+ def __call__(self, get_handler_result):
+ async def get_handler_result_wrapped_by_middleware(event: BaseEvent):
+ print(f"๐ Logging: Processing event {event.event_type}")
+
+ try:
+ result = await get_handler_result(event)
+ print(f"๐ Logging: Handler succeeded")
+ return result
+ except Exception as e:
+ print(f"๐ Logging: Handler failed with error: {e}")
+ raise
+
+ return get_handler_result_wrapped_by_middleware
+
+
+def analytics_handler(event: HandlerStartedAnalyticsEvent | HandlerCompletedAnalyticsEvent) -> None:
+ """Handle analytics events"""
+ print(f"๐ Analytics: {event.event_type} - Handler: {event.handler_name} in {event.event_bus_name}")
+ if isinstance(event, HandlerCompletedAnalyticsEvent) and event.error:
+ print(f" Error: {event.error}")
+
+
+def test_handler(event: TestEvent) -> str:
+ """Simple test handler"""
+ print(f"๐ง Handler processing: {event.message}")
+ return f"Processed: {event.message}"
+
+
+def failing_handler(event: TestEvent) -> str:
+ """Handler that always fails"""
+ print(f"๐ฅ Failing handler processing: {event.message}")
+ raise ValueError("This handler always fails!")
+
+
+async def test_basic_middleware():
+ """Test basic middleware functionality"""
+ print("๐งช Testing basic middleware functionality...")
+
+ # Create analytics bus
+ analytics_bus = EventBus(name='AnalyticsBus')
+ analytics_bus.on('*', analytics_handler)
+
+ # Create event bus with middleware
+ event_bus = EventBus(
+ name='TestEventBus',
+ middlewares=[
+ LoggingMiddleware(),
+ AnalyticsMiddleware(analytics_bus),
+ ],
+ )
+
+ # Register handlers
+ event_bus.on(TestEvent, test_handler)
+ event_bus.on(TestEvent, failing_handler)
+
+ # Test with successful event
+ print("\nโ
Testing event processing...")
+ test_event = TestEvent(message="Hello, Django-style middleware!")
+ completed_event = await event_bus.dispatch(test_event)
+ print(f"Event completed: {completed_event.event_status}")
+
+ # Wait for analytics to process
+ await asyncio.sleep(0.1)
+
+ # Stop the buses
+ await event_bus.stop()
+ await analytics_bus.stop()
+
+ print("โ
Basic middleware test completed successfully!")
+
+
+async def test_wal_middleware():
+ """Test WAL middleware functionality"""
+ print("\n๐งช Testing WAL middleware...")
+
+ with TemporaryDirectory() as tmp_dir:
+ wal_path = Path(tmp_dir) / "test_events.jsonl"
+
+ # Create event bus with WAL enabled
+ event_bus = EventBus(
+ name='WALTestBus',
+ wal_path=wal_path, # This should automatically add WALEventBusMiddleware
+ )
+
+ # Register a handler
+ event_bus.on(TestEvent, test_handler)
+
+ # Dispatch an event
+ test_event = TestEvent(message="WAL test message")
+ await event_bus.dispatch(test_event)
+
+ # Wait for processing
+ await asyncio.sleep(0.1)
+
+ # Check if WAL file was created and contains the event
+ if wal_path.exists():
+ content = wal_path.read_text()
+ if "WAL test message" in content:
+ print("โ
WAL middleware working correctly!")
+ else:
+ print("โ WAL file exists but doesn't contain expected content")
+ print(f"Content: {content}")
+ else:
+ print("โ WAL file was not created")
+
+ await event_bus.stop()
+
+
+async def test_custom_wal_middleware():
+ """Test using WALEventBusMiddleware explicitly"""
+ print("\n๐งช Testing custom WAL middleware...")
+
+ with TemporaryDirectory() as tmp_dir:
+ wal_path = Path(tmp_dir) / "custom_wal.jsonl"
+
+ # Create event bus with explicit WAL middleware
+ event_bus = EventBus(
+ name='CustomWALBus',
+ middlewares=[WALEventBusMiddleware(wal_path)],
+ )
+
+ # Register a handler
+ event_bus.on(TestEvent, test_handler)
+
+ # Dispatch an event
+ test_event = TestEvent(message="Custom WAL test")
+ await event_bus.dispatch(test_event)
+
+ # Wait for processing
+ await asyncio.sleep(0.1)
+
+ # Check WAL file
+ if wal_path.exists():
+ content = wal_path.read_text()
+ if "Custom WAL test" in content:
+ print("โ
Custom WAL middleware working correctly!")
+ else:
+ print("โ WAL file exists but doesn't contain expected content")
+ else:
+ print("โ Custom WAL file was not created")
+
+ await event_bus.stop()
+
+
+async def main():
+ """Run all middleware tests"""
+ print("๐งช Testing Django-style middleware functionality...")
+
+ await test_basic_middleware()
+ await test_wal_middleware()
+ await test_custom_wal_middleware()
+
+ print("\nโ
All middleware tests completed successfully!")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
\ No newline at end of file