11from abc import ABC , abstractmethod
2- from collections import defaultdict
32from collections .abc import Awaitable , Callable
4- from dataclasses import dataclass , field
5- from inspect import getmro , isclass
6- from types import GenericAlias
7- from typing import Any , Protocol , Self , TypeAliasType , runtime_checkable
3+ from typing import Any , Protocol , Self , runtime_checkable
84
95import anyio
10- import injection
6+ from anyio . abc import TaskGroup
117
128from cq ._core .dispatcher .base import BaseDispatcher , Dispatcher
13-
14- type HandlerType [** P , T ] = type [Handler [P , T ]]
15- type HandlerFactory [** P , T ] = Callable [..., Awaitable [Handler [P , T ]]]
9+ from cq ._core .handler import (
10+ HandlerFactory ,
11+ HandlerManager ,
12+ MultipleHandlerManager ,
13+ SingleHandlerManager ,
14+ )
1615
1716type Listener [T ] = Callable [[T ], Awaitable [Any ]]
1817
19- type BusType [I , O ] = type [Bus [I , O ]]
20-
21-
22- @runtime_checkable
23- class Handler [** P , T ](Protocol ):
24- __slots__ = ()
25-
26- @abstractmethod
27- async def handle (self , * args : P .args , ** kwargs : P .kwargs ) -> T :
28- raise NotImplementedError
29-
3018
3119@runtime_checkable
3220class Bus [I , O ](Dispatcher [I , O ], Protocol ):
@@ -41,27 +29,6 @@ def add_listeners(self, *listeners: Listener[I]) -> Self:
4129 raise NotImplementedError
4230
4331
44- @dataclass (eq = False , frozen = True , slots = True )
45- class SubscriberDecorator [I , O ]:
46- bus_type : BusType [I , O ] | TypeAliasType | GenericAlias
47- injection_module : injection .Module = field (default_factory = injection .mod )
48-
49- def __call__ (self , first_input_type : type [I ], / , * input_types : type [I ]) -> Any :
50- def decorator (wrapped : type [Handler [[I ], O ]]) -> type [Handler [[I ], O ]]:
51- if not isclass (wrapped ) or not issubclass (wrapped , Handler ):
52- raise TypeError (f"`{ wrapped } ` isn't a valid handler." )
53-
54- bus = self .injection_module .find_instance (self .bus_type )
55- factory = self .injection_module .make_async_factory (wrapped )
56-
57- for input_type in (first_input_type , * input_types ):
58- bus .subscribe (input_type , factory )
59-
60- return wrapped
61-
62- return decorator
63-
64-
6532class BaseBus [I , O ](BaseDispatcher [I , O ], Bus [I , O ], ABC ):
6633 __slots__ = ("__listeners" ,)
6734
@@ -75,81 +42,47 @@ def add_listeners(self, *listeners: Listener[I]) -> Self:
7542 self .__listeners .extend (listeners )
7643 return self
7744
78- async def _trigger_listeners (self , input_value : I , / ) -> None :
79- listeners = self .__listeners
80-
81- if not listeners :
82- return
83-
84- async with anyio .create_task_group () as task_group :
85- for listener in listeners :
86- task_group .start_soon (listener , input_value )
87-
88- @staticmethod
89- def _make_handle_function (
90- handler_factory : HandlerFactory [[I ], O ],
91- ) -> Callable [[I ], Awaitable [O ]]:
92- async def handle (input_value : I ) -> O :
93- handler = await handler_factory ()
94- return await handler .handle (input_value )
95-
96- return handle
45+ def _trigger_listeners (self , input_value : I , / , task_group : TaskGroup ) -> None :
46+ for listener in self .__listeners :
47+ task_group .start_soon (listener , input_value )
9748
9849
9950class SimpleBus [I , O ](BaseBus [I , O ]):
100- __slots__ = ("__handlers " ,)
51+ __slots__ = ("__manager " ,)
10152
102- __handlers : dict [ type [ I ], HandlerFactory [[ I ], O ] ]
53+ __manager : HandlerManager [ I , O ]
10354
104- def __init__ (self ) -> None :
55+ def __init__ (self , manager : HandlerManager [ I , O ] | None = None ) -> None :
10556 super ().__init__ ()
106- self .__handlers = {}
57+ self .__manager = manager or SingleHandlerManager ()
10758
10859 async def dispatch (self , input_value : I , / ) -> O :
109- await self ._trigger_listeners (input_value )
110-
111- for input_type in getmro (type (input_value )):
112- if handler_factory := self .__handlers .get (input_type ):
113- break
60+ async with anyio .create_task_group () as task_group :
61+ self ._trigger_listeners (input_value , task_group )
11462
115- else :
116- return NotImplemented
63+ for handler in self . __manager . handlers_from ( type ( input_value )) :
64+ return await self . _invoke_with_middlewares ( handler , input_value )
11765
118- handler = self ._make_handle_function (handler_factory )
119- return await self ._invoke_with_middlewares (handler , input_value )
66+ return NotImplemented
12067
12168 def subscribe (self , input_type : type [I ], factory : HandlerFactory [[I ], O ]) -> Self :
122- if input_type in self .__handlers :
123- raise RuntimeError (
124- f"A handler is already registered for the input type: `{ input_type } `."
125- )
126-
127- self .__handlers [input_type ] = factory
69+ self .__manager .subscribe (input_type , factory )
12870 return self
12971
13072
13173class TaskBus [I ](BaseBus [I , None ]):
132- __slots__ = ("__handlers " ,)
74+ __slots__ = ("__manager " ,)
13375
134- __handlers : dict [ type [ I ], list [ HandlerFactory [[ I ], None ]] ]
76+ __manager : HandlerManager [ I , None ]
13577
136- def __init__ (self ) -> None :
78+ def __init__ (self , manager : HandlerManager [ I , None ] | None = None ) -> None :
13779 super ().__init__ ()
138- self .__handlers = defaultdict ( list )
80+ self .__manager = manager or MultipleHandlerManager ( )
13981
14082 async def dispatch (self , input_value : I , / ) -> None :
141- await self ._trigger_listeners (input_value )
142-
143- for input_type in getmro (type (input_value )):
144- if handler_factories := self .__handlers .get (input_type ):
145- break
146-
147- else :
148- return
149-
15083 async with anyio .create_task_group () as task_group :
151- for handler_factory in handler_factories :
152- handler = self ._make_handle_function ( handler_factory )
84+ self . _trigger_listeners ( input_value , task_group )
85+ for handler in self .__manager . handlers_from ( type ( input_value )):
15386 task_group .start_soon (
15487 self ._invoke_with_middlewares ,
15588 handler ,
@@ -161,5 +94,5 @@ def subscribe(
16194 input_type : type [I ],
16295 factory : HandlerFactory [[I ], None ],
16396 ) -> Self :
164- self .__handlers [ input_type ]. append ( factory )
97+ self .__manager . subscribe ( input_type , factory )
16598 return self
0 commit comments