Skip to content

Commit 62a93f6

Browse files
committed
MOD: Support file paths in Live.add_stream
1 parent 935e67f commit 62a93f6

File tree

3 files changed

+102
-2
lines changed

3 files changed

+102
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
#### Enhancements
66
- Added `map_symbols` support for DBN data generated by the `Live` client
7+
- Added support for file paths in `Live.add_stream`
78

89
#### Bug fixes
910
- Fixed an issue where `DBNStore.from_bytes` did not rewind seekable buffers
1011
- Fixed an issue where the `DBNStore` would not map symbols with input symbology of `SType.INSTRUMENT_ID`
1112
- Fixed an issue with `DBNStore.request_symbology` when the DBN metadata's start date and end date were the same
13+
- Fixed an issue where closed streams were not removed from a `Live` client on shutdown
1214

1315
## 0.20.0 - 2023-09-21
1416

databento/live/client.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import asyncio
44
import logging
55
import os
6+
import pathlib
67
import queue
78
import threading
89
from collections.abc import Iterable
910
from concurrent import futures
1011
from numbers import Number
12+
from os import PathLike
1113
from typing import IO
1214

1315
import databento_dbn
@@ -307,15 +309,15 @@ def add_callback(
307309

308310
def add_stream(
309311
self,
310-
stream: IO[bytes],
312+
stream: IO[bytes] | PathLike[str] | str,
311313
exception_callback: ExceptionCallback | None = None,
312314
) -> None:
313315
"""
314316
Add an IO stream to write records to.
315317
316318
Parameters
317319
----------
318-
stream : IO[bytes]
320+
stream : IO[bytes] or PathLike[str] or str
319321
The IO stream to write to when handling live records as they arrive.
320322
exception_callback : Callable[[Exception], None], optional
321323
An error handling callback to process exceptions that are raised
@@ -325,12 +327,17 @@ def add_stream(
325327
------
326328
ValueError
327329
If `stream` is not a writable byte stream.
330+
OSError
331+
If `stream` is not a path to a writable file.
328332
329333
See Also
330334
--------
331335
Live.add_callback
332336
333337
"""
338+
if isinstance(stream, (str, PathLike)):
339+
stream = pathlib.Path(stream).open("wb")
340+
334341
if not hasattr(stream, "write"):
335342
raise ValueError(f"{type(stream).__name__} does not support write()")
336343

@@ -589,6 +596,19 @@ async def _shutdown(self) -> None:
589596
if self._session is None:
590597
return
591598
await self._session.wait_for_close()
599+
600+
to_remove = []
601+
for stream in self._user_streams:
602+
stream_name = getattr(stream, "name", str(stream))
603+
if stream.closed:
604+
logger.info("removing closed user stream %s", stream_name)
605+
to_remove.append(stream)
606+
else:
607+
stream.flush()
608+
609+
for key in to_remove:
610+
self._user_streams.pop(key)
611+
592612
self._symbology_map.clear()
593613

594614
def _map_symbol(self, record: DBNRecord) -> None:

tests/test_live_client.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def test_live_start_twice(
309309
with pytest.raises(ValueError):
310310
live_client.start()
311311

312+
312313
def test_live_start_before_subscribe(
313314
live_client: client.Live,
314315
) -> None:
@@ -318,6 +319,7 @@ def test_live_start_before_subscribe(
318319
with pytest.raises(ValueError):
319320
live_client.start()
320321

322+
321323
@pytest.mark.parametrize(
322324
"schema",
323325
[pytest.param(schema, id=str(schema)) for schema in Schema.variants()],
@@ -428,6 +430,34 @@ def test_live_stop(
428430
live_client.block_for_close()
429431

430432

433+
@pytest.mark.usefixtures("mock_live_server")
434+
def test_live_shutdown_remove_closed_stream(
435+
tmp_path: pathlib.Path,
436+
live_client: client.Live,
437+
) -> None:
438+
"""
439+
Test that closed streams are removed upon disconnection.
440+
"""
441+
live_client.subscribe(
442+
dataset=Dataset.GLBX_MDP3,
443+
schema=Schema.MBO,
444+
)
445+
446+
output = tmp_path / "output.dbn"
447+
448+
with output.open("wb") as out:
449+
live_client.add_stream(out)
450+
451+
assert live_client.is_connected() is True
452+
453+
live_client.start()
454+
455+
live_client.stop()
456+
live_client.block_for_close()
457+
458+
assert live_client._user_streams == {}
459+
460+
431461
def test_live_stop_twice(
432462
live_client: client.Live,
433463
) -> None:
@@ -575,6 +605,15 @@ def test_live_add_stream_invalid(
575605
with pytest.raises(ValueError):
576606
live_client.add_stream(readable_file.open(mode="rb"))
577607

608+
def test_live_add_stream_path_directory(
609+
tmp_path: pathlib.Path,
610+
live_client: client.Live,
611+
) -> None:
612+
"""
613+
Test that passing a path to a directory raises an OSError.
614+
"""
615+
with pytest.raises(OSError):
616+
live_client.add_stream(tmp_path)
578617

579618
@pytest.mark.skipif(platform.system() == "Windows", reason="flaky on windows runner")
580619
async def test_live_async_iteration(
@@ -730,6 +769,7 @@ def test_live_sync_iteration(
730769
assert isinstance(records[2], databento_dbn.MBOMsg)
731770
assert isinstance(records[3], databento_dbn.MBOMsg)
732771

772+
733773
async def test_live_callback(
734774
live_client: client.Live,
735775
) -> None:
@@ -800,6 +840,44 @@ async def test_live_stream_to_dbn(
800840
assert output.read_bytes() == expected_data.read()
801841

802842

843+
@pytest.mark.parametrize(
844+
"schema",
845+
(pytest.param(schema, id=str(schema)) for schema in Schema.variants()),
846+
)
847+
async def test_live_stream_to_dbn_from_path(
848+
tmp_path: pathlib.Path,
849+
test_data_path: Callable[[Schema], pathlib.Path],
850+
live_client: client.Live,
851+
schema: Schema,
852+
) -> None:
853+
"""
854+
Test that DBN data streamed by the MockLiveServer is properly
855+
reconstructed client side when specifying a file as a path.
856+
"""
857+
output = tmp_path / "output.dbn"
858+
859+
live_client.subscribe(
860+
dataset=Dataset.GLBX_MDP3,
861+
schema=schema,
862+
stype_in=SType.RAW_SYMBOL,
863+
symbols="TEST",
864+
)
865+
live_client.add_stream(output)
866+
867+
live_client.start()
868+
869+
await live_client.wait_for_close()
870+
871+
expected_data = BytesIO(
872+
zstandard.ZstdDecompressor()
873+
.stream_reader(test_data_path(schema).open("rb"))
874+
.read(),
875+
)
876+
expected_data.seek(0) # rewind
877+
878+
assert output.read_bytes() == expected_data.read()
879+
880+
803881
@pytest.mark.parametrize(
804882
"schema",
805883
(pytest.param(schema, id=str(schema)) for schema in Schema.variants()),

0 commit comments

Comments
 (0)