Skip to content

Commit 823d56f

Browse files
committed
MOD: Implement Iterable for Bento class
1 parent 98f68bf commit 823d56f

File tree

2 files changed

+90
-9
lines changed

2 files changed

+90
-9
lines changed

databento/common/bento.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,21 @@
22

33
import abc
44
import datetime as dt
5+
import logging
56
from io import BytesIO
67
from os import PathLike
78
from pathlib import Path
8-
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
9+
from typing import (
10+
IO,
11+
TYPE_CHECKING,
12+
Any,
13+
Callable,
14+
Dict,
15+
Generator,
16+
List,
17+
Optional,
18+
Union,
19+
)
920

1021
import numpy as np
1122
import pandas as pd
@@ -23,6 +34,8 @@
2334
from databento.common.symbology import ProductIdMappingInterval
2435

2536

37+
logger = logging.getLogger(__name__)
38+
2639
if TYPE_CHECKING:
2740
from databento.historical.client import Historical
2841

@@ -274,6 +287,14 @@ def __init__(self, data_source: DataSource) -> None:
274287
Dict[int, str],
275288
] = {}
276289

290+
def __iter__(self) -> Generator[np.void, None, None]:
    """
    Iterate over the records, yielding them one at a time.

    At most ``self.record_count`` records are yielded. Iteration ends
    early if the reader returns no more data.

    Yields
    ------
    np.void
        The next record, decoded using the struct dtype registered in
        ``STRUCT_MAP`` for ``self.schema``.

    """
    # Resolve the loop invariants once rather than on every iteration.
    # NOTE(review): if `reader` is a property that builds or re-seeks a
    # stream on each access, binding it once is also what lets the loop
    # advance through the data - confirm against the `reader` definition.
    dtype = STRUCT_MAP[self.schema]
    record_size = self.record_size
    reader = self.reader

    for _ in range(self.record_count):
        rec = np.frombuffer(reader.read(record_size), dtype=dtype)
        if rec.size == 0:
            # A generator must finish with `return`; under PEP 479 a
            # `raise StopIteration` here surfaces as a RuntimeError.
            return
        yield rec[0]
297+
277298
def _apply_pretty_ts(self, df: pd.DataFrame) -> pd.DataFrame:
278299
df.index = pd.to_datetime(df.index, utc=True)
279300
for column in df.columns:
@@ -672,14 +693,15 @@ def replay(self, callback: Callable[[Any], None]) -> None:
672693
The callback to the data handler.
673694
674695
"""
675-
dtype = STRUCT_MAP[self.schema]
676-
reader: IO[bytes] = self.reader
677-
while True:
678-
raw: bytes = reader.read(self.record_size)
679-
record = np.frombuffer(raw, dtype=dtype)
680-
if record.size == 0:
681-
break
682-
callback(record[0])
696+
for record in self:
697+
try:
698+
callback(record)
699+
except Exception as exc:
700+
logger.exception(
701+
"exception while replaying to user callback",
702+
exc_info=exc,
703+
)
704+
raise
683705

684706
def request_full_definitions(
685707
self,

tests/test_historical_bento.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,62 @@ def test_mbp_1_to_json_with_all_options_writes_expected_file_to_disk(self) -> No
598598

599599
# Cleanup
600600
os.remove(path)
601+
602+
def test_bento_iterable(self) -> None:
    """
    Check that a Bento can be consumed by iteration: the object is
    materialized with ``list()`` and the first two records are compared
    against their expected string form.
    """
    # Arrange
    expected = (
        "(14, 160, 1, 5482, 1609160400000429831, 647784973705, "
        "3722750000000, 1, -128, 0, b'C', b'A', 1609160400000704060, "
        "22993, 1170352)"
    )
    payload = get_test_data(schema=Schema.MBO)
    bento = Bento.from_bytes(data=payload)

    # Act
    records = list(bento)

    # Assert
    # Index explicitly so a list with fewer than two records still fails.
    for position in (0, 1):
        assert str(records[position]) == expected
624+
625+
def test_bento_iterable_parallel(self) -> None:
    """
    Check that two iterators over the same Bento can be advanced in an
    interleaved order, with the intent that calling next() on one does
    not affect the other. Every record returned is compared against the
    same expected string form.
    """
    # Arrange
    expected = (
        "(14, 160, 1, 5482, 1609160400000429831, 647784973705, "
        "3722750000000, 1, -128, 0, b'C', b'A', 1609160400000704060, "
        "22993, 1170352)"
    )
    payload = get_test_data(schema=Schema.MBO)
    bento = Bento.from_bytes(data=payload)

    first = iter(bento)
    second = iter(bento)

    # Act / Assert
    # Advance the iterators in the order: first, second, second, first.
    for iterator in (first, second, second, first):
        assert str(next(iterator)) == expected

0 commit comments

Comments
 (0)