|
16 | 16 | from functools import singledispatchmethod |
17 | 17 | from io import BytesIO |
18 | 18 | from os import PathLike |
| 19 | +from typing import Any |
19 | 20 | from typing import Callable |
20 | 21 | from typing import NewType |
21 | 22 | from typing import TypeVar |
|
32 | 33 | from databento.live.gateway import SessionStart |
33 | 34 | from databento.live.gateway import SubscriptionRequest |
34 | 35 | from databento.live.gateway import parse_gateway_message |
35 | | -from databento_dbn import Metadata |
36 | 36 | from databento_dbn import Schema |
37 | | -from databento_dbn import SType |
38 | 37 |
|
39 | 38 |
|
40 | 39 | LIVE_SERVER_VERSION: str = "1.0.0" |
@@ -100,7 +99,8 @@ def __init__( |
100 | 99 | self._version: str = version |
101 | 100 | self._is_authenticated: bool = False |
102 | 101 | self._is_streaming: bool = False |
103 | | - self._repeater_tasks: set[asyncio.Task[None]] = set() |
| 102 | + self._subscriptions: list[SubscriptionRequest] = [] |
| 103 | + self._tasks: set[asyncio.Task[None]] = set() |
104 | 104 |
|
105 | 105 | self._dbn_path = pathlib.Path(dbn_path) |
106 | 106 | self._user_api_keys = user_api_keys |
@@ -155,6 +155,18 @@ def is_streaming(self) -> bool: |
155 | 155 | """ |
156 | 156 | return self._is_streaming |
157 | 157 |
|
@property
def dataset_path(self) -> pathlib.Path:
    """
    The directory holding the DBN files served for the current dataset.

    Resolves to the base DBN path itself when no dataset is set.

    Returns
    -------
    Path

    """
    dataset_dir = self._dataset if self._dataset else ""
    return self._dbn_path / dataset_dir
| 169 | + |
158 | 170 | @property |
159 | 171 | def mode(self) -> MockLiveMode: |
160 | 172 | """ |
@@ -205,6 +217,18 @@ def session_id(self) -> str: |
205 | 217 | """ |
206 | 218 | return str(hash(self)) |
207 | 219 |
|
@property
def subscriptions(self) -> tuple[SubscriptionRequest, ...]:
    """
    An immutable snapshot of the subscription requests received.

    Returns
    -------
    tuple[SubscriptionRequest, ...]

    """
    return (*self._subscriptions,)
| 231 | + |
208 | 232 | @property |
209 | 233 | def version(self) -> str: |
210 | 234 | """ |
@@ -353,11 +377,11 @@ def _(self, message: AuthenticationRequest) -> None: |
353 | 377 | logger.info("received CRAM response: %s", message.auth) |
354 | 378 | if self.is_authenticated: |
355 | 379 | logger.error("authentication request sent when already authenticated") |
356 | | - self.__transport.close() |
| 380 | + self.__transport.write_eof() |
357 | 381 | return |
358 | 382 | if self.is_streaming: |
359 | 383 | logger.error("authentication request sent while streaming") |
360 | | - self.__transport.close() |
| 384 | + self.__transport.write_eof() |
361 | 385 | return |
362 | 386 |
|
363 | 387 | _, bucket_id = message.auth.split("-") |
@@ -400,54 +424,49 @@ def _(self, message: SubscriptionRequest) -> None: |
400 | 424 | logger.info("received subscription request: %s", str(message).strip()) |
401 | 425 | if not self.is_authenticated: |
402 | 426 | logger.error("subscription request sent while unauthenticated") |
403 | | - self.__transport.close() |
| 427 | + self.__transport.write_eof() |
404 | 428 |
|
405 | | - if self.is_streaming: |
406 | | - logger.error("subscription request sent while streaming") |
407 | | - self.__transport.close() |
| 429 | + self._subscriptions.append(message) |
408 | 430 |
|
409 | | - self._schemas.append(Schema(message.schema)) |
| 431 | + if self.is_streaming: |
| 432 | + self.create_server_task(message) |
410 | 433 |
|
@handle_client_message.register(SessionStart)
def _(self, message: SessionStart) -> None:
    """
    Handle a session start request by switching to streaming.

    A server task is created for every subscription received so far.

    Parameters
    ----------
    message : SessionStart
        The session start request from the client.

    """
    logger.info("received session start request: %s", str(message).strip())
    self._is_streaming = True

    pending = self.subscriptions
    for subscription in pending:
        self.create_server_task(subscription)
| 441 | + |
def create_server_task(self, message: SubscriptionRequest) -> None:
    """
    Start a background task which streams data for a subscription.

    In replay mode the task replays the stored DBN files for the requested
    schema; in any other mode it runs the repeater for that schema. The task
    is tracked in ``self._tasks`` until it finishes.

    Parameters
    ----------
    message : SubscriptionRequest
        The subscription request to serve.

    """
    schema = Schema(message.schema)
    if self.mode is MockLiveMode.REPLAY:
        task = asyncio.create_task(self.replay_task(schema=schema))
    else:
        task = asyncio.create_task(self.repeater_task(schema=schema))

    self._tasks.add(task)

    def _on_done(finished: "asyncio.Task[None]") -> None:
        # One callback makes the ordering explicit: the task must leave the
        # set before `check_done` looks at whether any tasks remain.
        # `discard` tolerates a task that was already removed, where
        # `remove` would raise KeyError inside the callback.
        self._tasks.discard(finished)

        # Retrieve any failure so it is reported instead of silently lost.
        if not finished.cancelled():
            error = finished.exception()
            if error is not None:
                logger.error("streaming task failed: %r", error)

        self.check_done(finished)

    task.add_done_callback(_on_done)
434 | 451 |
|
def check_done(self, _: Any) -> None:
    """
    Signal end-of-stream once no streaming tasks remain.

    Does nothing while ``self._tasks`` is non-empty.

    Parameters
    ----------
    _ : Any
        Unused.

    """
    if self._tasks:
        return

    logger.info("streaming tasks completed")
    self.__transport.write_eof()
437 | 456 |
|
438 | | - elif self.mode is MockLiveMode.REPEAT: |
439 | | - metadata = Metadata("UNIT.TEST", 0, SType.RAW_SYMBOL, [], [], [], []) |
440 | | - self.__transport.write(bytes(metadata)) |
441 | | - |
442 | | - loop = asyncio.get_event_loop() |
443 | | - for schema in self._schemas: |
444 | | - task = loop.create_task(self.repeater(schema)) |
445 | | - self._repeater_tasks.add(task) |
446 | | - task.add_done_callback(self._repeater_tasks.remove) |
447 | | - else: |
448 | | - raise ValueError(f"unsupported mode {MockLiveMode.REPEAT}") |
async def replay_task(self, schema: Schema) -> None:
    """
    Write the decompressed contents of each matching DBN file to the transport.

    Files are looked up in the dataset path with the glob pattern
    ``*{schema}.dbn.zst``.

    Parameters
    ----------
    schema : Schema
        The schema whose test data files are replayed.

    """
    pattern = f"*{schema}.dbn.zst"
    for dbn_file in self.dataset_path.glob(pattern):
        dctx = zstandard.ZstdDecompressor()
        reader = dctx.stream_reader(dbn_file.read_bytes())
        logger.info("streaming %s for %s schema", dbn_file.name, schema)
        self.__transport.write(reader.readall())
449 | 468 |
|
450 | | - async def repeater(self, schema: Schema) -> None: |
| 469 | + async def repeater_task(self, schema: Schema) -> None: |
451 | 470 | struct = SCHEMA_STRUCT_MAP[schema] |
452 | 471 | repeated = bytes(struct(*[0] * 12)) # for now we only support MBP_1 |
453 | 472 |
|
|
0 commit comments