16
16
import time
17
17
from collections import defaultdict
18
18
from concurrent .futures import ThreadPoolExecutor
19
- from typing import Any , Callable , Dict , Generic , List , Optional , TypeVar
19
+ from typing import (
20
+ Any ,
21
+ Callable ,
22
+ Dict ,
23
+ Generic ,
24
+ Optional ,
25
+ TypeVar ,
26
+ )
20
27
from uuid import uuid4
21
28
22
29
from pydantic import BaseModel , ConfigDict , Field
@@ -46,15 +53,17 @@ class ParametrizedCallback(BaseModel, Generic[T]):
46
53
# Callback is of type T if raw is False, otherwise it is of type Any
47
54
callback : Callable [[T | Any ], None ]
48
55
raw : bool
56
+ id : str = Field (default_factory = lambda : str (uuid4 ()))
49
57
50
58
51
59
class BaseConnector (Generic [T ]):
52
60
def __init__ (self , callback_max_workers : int = 4 ):
53
61
self .callback_max_workers = callback_max_workers
54
62
self .logger = logging .getLogger (self .__class__ .__name__ )
55
- self .registered_callbacks : Dict [str , List [ ParametrizedCallback [T ]]] = (
56
- defaultdict (list )
63
+ self .registered_callbacks : Dict [str , Dict [ str , ParametrizedCallback [T ]]] = (
64
+ defaultdict (dict )
57
65
)
66
+ self .callback_id_mapping : Dict [str , tuple [str , ParametrizedCallback [T ]]] = {}
58
67
self .callback_executor = ThreadPoolExecutor (
59
68
max_workers = self .callback_max_workers
60
69
)
@@ -90,7 +99,7 @@ def register_callback(
90
99
callback : Callable [[T | Any ], None ],
91
100
raw : bool = False ,
92
101
** kwargs : Any ,
93
- ) -> None :
102
+ ) -> str :
94
103
"""Implements register callback.
95
104
96
105
Registers a callback to be called when a message is received from a source.
@@ -100,9 +109,31 @@ def register_callback(
100
109
Raises:
101
110
ConnectorException: If the callback cannot be registered.
102
111
"""
103
- self .registered_callbacks [source ].append (
104
- ParametrizedCallback (callback = callback , raw = raw )
112
+ parametrized_callback = ParametrizedCallback [T ](callback = callback , raw = raw )
113
+ self .registered_callbacks [source ][parametrized_callback .id ] = (
114
+ parametrized_callback
115
+ )
116
+ self .callback_id_mapping [parametrized_callback .id ] = (
117
+ source ,
118
+ parametrized_callback ,
105
119
)
120
+ return parametrized_callback .id
121
+
122
+ def unregister_callback (self , callback_id : str ) -> None :
123
+ """Unregisters a callback from a source.
124
+
125
+ Args:
126
+ callback_id: The id of the callback to unregister.
127
+
128
+ Raises:
129
+ ConnectorException: If the callback cannot be unregistered.
130
+ """
131
+ if callback_id not in self .callback_id_mapping :
132
+ raise ConnectorException (f"Callback with id { callback_id } not found." )
133
+
134
+ source , _ = self .callback_id_mapping [callback_id ]
135
+ del self .registered_callbacks [source ][callback_id ]
136
+ del self .callback_id_mapping [callback_id ]
106
137
107
138
def _safe_callback_wrapper (self , callback : Callable [[T ], None ], message : T ) -> None :
108
139
"""Safely execute a callback with error handling.
@@ -120,7 +151,7 @@ def general_callback(self, source: str, message: Any) -> None:
120
151
"""General callback for all messages.
121
152
Use through functools.partial to pass source."""
122
153
processed_message = self .general_callback_preprocessor (message )
123
- for parametrized_callback in self .registered_callbacks .get (source , [] ):
154
+ for parametrized_callback in self .registered_callbacks .get (source , {}). values ( ):
124
155
payload = message if parametrized_callback .raw else processed_message
125
156
self .callback_executor .submit (
126
157
self ._safe_callback_wrapper , parametrized_callback .callback , payload
0 commit comments