Skip to content

Commit 28eb194

Browse files
authored
feat: unregister callback (#534)
1 parent 5532c2f commit 28eb194

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

src/rai_core/rai/communication/base_connector.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
import time
1717
from collections import defaultdict
1818
from concurrent.futures import ThreadPoolExecutor
19-
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar
19+
from typing import (
20+
Any,
21+
Callable,
22+
Dict,
23+
Generic,
24+
Optional,
25+
TypeVar,
26+
)
2027
from uuid import uuid4
2128

2229
from pydantic import BaseModel, ConfigDict, Field
@@ -46,15 +53,17 @@ class ParametrizedCallback(BaseModel, Generic[T]):
4653
# Callback is of type T if raw is False, otherwise it is of type Any
4754
callback: Callable[[T | Any], None]
4855
raw: bool
56+
id: str = Field(default_factory=lambda: str(uuid4()))
4957

5058

5159
class BaseConnector(Generic[T]):
5260
def __init__(self, callback_max_workers: int = 4):
5361
self.callback_max_workers = callback_max_workers
5462
self.logger = logging.getLogger(self.__class__.__name__)
55-
self.registered_callbacks: Dict[str, List[ParametrizedCallback[T]]] = (
56-
defaultdict(list)
63+
self.registered_callbacks: Dict[str, Dict[str, ParametrizedCallback[T]]] = (
64+
defaultdict(dict)
5765
)
66+
self.callback_id_mapping: Dict[str, tuple[str, ParametrizedCallback[T]]] = {}
5867
self.callback_executor = ThreadPoolExecutor(
5968
max_workers=self.callback_max_workers
6069
)
@@ -90,7 +99,7 @@ def register_callback(
9099
callback: Callable[[T | Any], None],
91100
raw: bool = False,
92101
**kwargs: Any,
93-
) -> None:
102+
) -> str:
94103
"""Implements register callback.
95104
96105
Registers a callback to be called when a message is received from a source.
@@ -100,9 +109,31 @@ def register_callback(
100109
Raises:
101110
ConnectorException: If the callback cannot be registered.
102111
"""
103-
self.registered_callbacks[source].append(
104-
ParametrizedCallback(callback=callback, raw=raw)
112+
parametrized_callback = ParametrizedCallback[T](callback=callback, raw=raw)
113+
self.registered_callbacks[source][parametrized_callback.id] = (
114+
parametrized_callback
115+
)
116+
self.callback_id_mapping[parametrized_callback.id] = (
117+
source,
118+
parametrized_callback,
105119
)
120+
return parametrized_callback.id
121+
122+
def unregister_callback(self, callback_id: str) -> None:
123+
"""Unregisters a callback from a source.
124+
125+
Args:
126+
callback_id: The id of the callback to unregister.
127+
128+
Raises:
129+
ConnectorException: If the callback cannot be unregistered.
130+
"""
131+
if callback_id not in self.callback_id_mapping:
132+
raise ConnectorException(f"Callback with id {callback_id} not found.")
133+
134+
source, _ = self.callback_id_mapping[callback_id]
135+
del self.registered_callbacks[source][callback_id]
136+
del self.callback_id_mapping[callback_id]
106137

107138
def _safe_callback_wrapper(self, callback: Callable[[T], None], message: T) -> None:
108139
"""Safely execute a callback with error handling.
@@ -120,7 +151,7 @@ def general_callback(self, source: str, message: Any) -> None:
120151
"""General callback for all messages.
121152
Use through functools.partial to pass source."""
122153
processed_message = self.general_callback_preprocessor(message)
123-
for parametrized_callback in self.registered_callbacks.get(source, []):
154+
for parametrized_callback in self.registered_callbacks.get(source, {}).values():
124155
payload = message if parametrized_callback.raw else processed_message
125156
self.callback_executor.submit(
126157
self._safe_callback_wrapper, parametrized_callback.callback, payload

0 commit comments

Comments
 (0)