Skip to content

Commit d91880b

Browse files
committed
ADD: Add exception handler to live client
1 parent 2ab8e48 commit d91880b

File tree

6 files changed

+122
-33
lines changed

6 files changed

+122
-33
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Enhancements
66
- Added `symbology_map` property to `Live` client
7+
- Changed `Live.add_callback` and `Live.add_stream` to accept an exception callback
8+
- Changed `Live.add_callback` and `Live.add_stream` `func` parameter to `record_callback`
79

810
## 0.14.1 - 2023-06-16
911

databento/live/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Union
1+
from typing import Callable, Union
22

33
import databento_dbn
44

55

66
AUTH_TIMEOUT_SECONDS: float = 2
77
CONNECT_TIMEOUT_SECONDS: float = 5
88

9+
910
DBNRecord = Union[
1011
databento_dbn.MBOMsg,
1112
databento_dbn.MBP1Msg,
@@ -19,3 +20,6 @@
1920
databento_dbn.SystemMsg,
2021
databento_dbn.ErrorMsg,
2122
]
23+
24+
RecordCallback = Callable[[DBNRecord], None]
25+
ExceptionCallback = Callable[[Exception], None]

databento/live/client.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Iterable
99
from concurrent import futures
1010
from numbers import Number
11-
from typing import IO, Callable
11+
from typing import IO
1212

1313
import databento_dbn
1414

@@ -23,6 +23,8 @@
2323
from databento.common.validation import validate_enum
2424
from databento.common.validation import validate_semantic_string
2525
from databento.live import DBNRecord
26+
from databento.live import ExceptionCallback
27+
from databento.live import RecordCallback
2628
from databento.live.session import DEFAULT_REMOTE_PORT
2729
from databento.live.session import DBNQueue
2830
from databento.live.session import Session
@@ -31,9 +33,6 @@
3133

3234

3335
logger = logging.getLogger(__name__)
34-
35-
UserCallback = Callable[[DBNRecord], None]
36-
3736
DEFAULT_QUEUE_SIZE = 2048
3837

3938

@@ -89,8 +88,8 @@ def __init__(
8988
self._dbn_queue: DBNQueue = DBNQueue(maxsize=DEFAULT_QUEUE_SIZE)
9089
self._metadata: SessionMetadata = SessionMetadata()
9190
self._symbology_map: dict[int, str | int] = {}
92-
self._user_callbacks: list[UserCallback] = [self._map_symbol]
93-
self._user_streams: list[IO[bytes]] = []
91+
self._user_callbacks: dict[RecordCallback, ExceptionCallback | None] = {}
92+
self._user_streams: dict[IO[bytes], ExceptionCallback | None] = {}
9493

9594
def factory() -> _SessionProtocol:
9695
return _SessionProtocol(
@@ -269,15 +268,19 @@ def ts_out(self) -> bool:
269268

270269
def add_callback(
271270
self,
272-
func: UserCallback,
271+
record_callback: RecordCallback,
272+
exception_callback: ExceptionCallback | None = None,
273273
) -> None:
274274
"""
275275
Add a callback for handling records.
276276
277277
Parameters
278278
----------
279-
func : Callable[[DBNRecord], None]
279+
record_callback : Callable[[DBNRecord], None]
280280
A callback to register for handling live records as they arrive.
281+
exception_callback : Callable[[Exception], None], optional
282+
An error handling callback to process exceptions that are raised
283+
in `record_callback`.
281284
282285
Raises
283286
------
@@ -289,20 +292,31 @@ def add_callback(
289292
Live.add_stream
290293
291294
"""
292-
if not callable(func):
293-
raise ValueError(f"{func} is not callable")
294-
callback_name = getattr(func, "__name__", str(func))
295+
if not callable(record_callback):
296+
raise ValueError(f"{record_callback} is not callable")
297+
298+
if exception_callback is not None and not callable(exception_callback):
299+
raise ValueError(f"{exception_callback} is not callable")
300+
301+
callback_name = getattr(record_callback, "__name__", str(record_callback))
295302
logger.info("adding user callback %s", callback_name)
296-
self._user_callbacks.append(func)
303+
self._user_callbacks[record_callback] = exception_callback
297304

298-
def add_stream(self, stream: IO[bytes]) -> None:
305+
def add_stream(
306+
self,
307+
stream: IO[bytes],
308+
exception_callback: ExceptionCallback | None = None,
309+
) -> None:
299310
"""
300311
Add an IO stream to write records to.
301312
302313
Parameters
303314
----------
304315
stream : IO[bytes]
305316
The IO stream to write to when handling live records as they arrive.
317+
exception_callback : Callable[[Exception], None], optional
318+
An error handling callback to process exceptions that are raised
319+
when writing to the stream.
306320
307321
Raises
308322
------
@@ -320,11 +334,14 @@ def add_stream(self, stream: IO[bytes]) -> None:
320334
if not hasattr(stream, "writable") or not stream.writable():
321335
raise ValueError(f"{type(stream).__name__} is not a writable stream")
322336

337+
if exception_callback is not None and not callable(exception_callback):
338+
raise ValueError(f"{exception_callback} is not callable")
339+
323340
stream_name = getattr(stream, "name", str(stream))
324341
logger.info("adding user stream %s", stream_name)
325342
if self.metadata is not None:
326343
stream.write(bytes(self.metadata))
327-
self._user_streams.append(stream)
344+
self._user_streams[stream] = exception_callback
328345

329346
def start(
330347
self,

databento/live/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def eof_received(self) -> bool | None:
186186
asycnio.BufferedProtocol.eof_received
187187
188188
"""
189-
logger.info("received EOF file from remote")
189+
logger.info("received EOF from remote")
190190
return super().eof_received()
191191

192192
def get_buffer(self, sizehint: int) -> bytearray:

databento/live/session.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
from databento.live import AUTH_TIMEOUT_SECONDS
2121
from databento.live import CONNECT_TIMEOUT_SECONDS
2222
from databento.live import DBNRecord
23+
from databento.live import ExceptionCallback
24+
from databento.live import RecordCallback
2325
from databento.live.protocol import DatabentoLiveProtocol
2426

2527

2628
logger = logging.getLogger(__name__)
2729

2830

29-
UserCallback = Callable[[DBNRecord], None]
3031
DEFAULT_REMOTE_PORT = 13000
3132

3233

@@ -122,8 +123,8 @@ def __init__(
122123
api_key: str,
123124
dataset: Dataset | str,
124125
dbn_queue: DBNQueue,
125-
user_callbacks: list[UserCallback],
126-
user_streams: list[IO[bytes]],
126+
user_callbacks: dict[RecordCallback, ExceptionCallback | None],
127+
user_streams: dict[IO[bytes], ExceptionCallback | None],
127128
loop: asyncio.AbstractEventLoop,
128129
metadata: SessionMetadata,
129130
ts_out: bool = False,
@@ -140,22 +141,28 @@ def __init__(
140141
def received_metadata(self, metadata: databento_dbn.Metadata) -> None:
141142
if not self._metadata:
142143
self._metadata.data = metadata
143-
for stream in self._user_streams:
144-
task = self._loop.create_task(self._stream_task(stream, metadata))
144+
for stream, exc_callback in self._user_streams.items():
145+
task = self._loop.create_task(
146+
self._stream_task(stream, metadata, exc_callback),
147+
)
145148
task.add_done_callback(self._tasks.remove)
146149
self._tasks.add(task)
147150
else:
148151
self._metadata.check(metadata)
149152
return super().received_metadata(metadata)
150153

151154
def received_record(self, record: DBNRecord) -> None:
152-
for callback in self._user_callbacks:
153-
task = self._loop.create_task(self._callback_task(callback, record))
155+
for callback, exc_callback in self._user_callbacks.items():
156+
task = self._loop.create_task(
157+
self._callback_task(callback, record, exc_callback),
158+
)
154159
task.add_done_callback(self._tasks.remove)
155160
self._tasks.add(task)
156161

157-
for stream in self._user_streams:
158-
task = self._loop.create_task(self._stream_task(stream, record))
162+
for stream, exc_callback in self._user_streams.items():
163+
task = self._loop.create_task(
164+
self._stream_task(stream, record, exc_callback),
165+
)
159166
task.add_done_callback(self._tasks.remove)
160167
self._tasks.add(task)
161168

@@ -180,26 +187,29 @@ def received_record(self, record: DBNRecord) -> None:
180187

181188
async def _callback_task(
182189
self,
183-
func: UserCallback,
190+
record_callback: RecordCallback,
184191
record: DBNRecord,
192+
exception_callback: ExceptionCallback | None,
185193
) -> None:
186194
try:
187-
func(record)
195+
record_callback(record)
188196
except Exception as exc:
189197
logger.error(
190198
"error dispatching %s to `%s` callback",
191199
type(record).__name__,
192-
func.__name__,
200+
record_callback.__name__,
193201
exc_info=exc,
194202
)
195-
raise
203+
if exception_callback is not None:
204+
self._loop.call_soon_threadsafe(exception_callback, exc)
196205

197206
async def _stream_task(
198207
self,
199208
stream: IO[bytes],
200209
record: databento_dbn.Metadata | DBNRecord,
210+
exc_callback: ExceptionCallback | None,
201211
) -> None:
202-
has_ts_out = self._metadata and self._metadata.data.ts_out
212+
has_ts_out = self._metadata.data and self._metadata.data.ts_out
203213
try:
204214
stream.write(bytes(record))
205215
if not isinstance(record, databento_dbn.Metadata) and has_ts_out:
@@ -212,7 +222,8 @@ async def _stream_task(
212222
stream_name,
213223
exc_info=exc,
214224
)
215-
raise
225+
if exc_callback is not None:
226+
self._loop.call_soon_threadsafe(exc_callback, exc)
216227

217228
async def wait_for_processing(self) -> None:
218229
while self._tasks:

tests/test_live_client.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,8 @@ def callback(_: object) -> None:
497497

498498
live_client.add_callback(callback)
499499
assert callback in live_client._user_callbacks
500-
assert live_client._user_streams == []
500+
assert live_client._user_callbacks[callback] is None
501+
assert live_client._user_streams == {}
501502

502503

503504
def test_live_add_stream(
@@ -509,7 +510,9 @@ def test_live_add_stream(
509510
stream = BytesIO()
510511

511512
live_client.add_stream(stream)
512-
assert live_client._user_streams == [stream]
513+
assert stream in live_client._user_streams
514+
assert live_client._user_streams[stream] is None
515+
assert live_client._user_callbacks == {}
513516

514517

515518
def test_live_add_stream_invalid(
@@ -1064,3 +1067,55 @@ def test_live_connection_reconnect_cram_failure(
10641067
dataset=Dataset.GLBX_MDP3,
10651068
schema=Schema.MBO,
10661069
)
1070+
1071+
async def test_live_callback_exception_handler(
1072+
live_client: client.Live,
1073+
) -> None:
1074+
"""
1075+
Test exceptions that occur during callbacks are dispatched to the assigned
1076+
exception handler.
1077+
"""
1078+
live_client.subscribe(
1079+
dataset=Dataset.GLBX_MDP3,
1080+
schema=Schema.MBO,
1081+
stype_in=SType.RAW_SYMBOL,
1082+
symbols="TEST",
1083+
)
1084+
1085+
exceptions: list[Exception] = []
1086+
1087+
def callback(_: DBNRecord) -> None:
1088+
raise RuntimeError("this is a test")
1089+
1090+
live_client.add_callback(callback, exceptions.append)
1091+
1092+
live_client.start()
1093+
1094+
await live_client.wait_for_close()
1095+
assert len(exceptions) == 4
1096+
1097+
1098+
async def test_live_stream_exception_handler(
1099+
live_client: client.Live,
1100+
) -> None:
1101+
"""
1102+
Test exceptions that occur during stream writes are dispatched to the
1103+
assigned exception handler.
1104+
"""
1105+
live_client.subscribe(
1106+
dataset=Dataset.GLBX_MDP3,
1107+
schema=Schema.MBO,
1108+
stype_in=SType.RAW_SYMBOL,
1109+
symbols="TEST",
1110+
)
1111+
1112+
exceptions: list[Exception] = []
1113+
1114+
stream = BytesIO()
1115+
live_client.add_stream(stream, exceptions.append)
1116+
stream.close()
1117+
1118+
live_client.start()
1119+
1120+
await live_client.wait_for_close()
1121+
assert len(exceptions) == 5 # extra write from metadata

0 commit comments

Comments
 (0)