Skip to content

Commit 649f700

Browse files
committed
MOD: Improve error handling for numpy decoding
1 parent 52be9d5 commit 649f700

File tree

2 files changed

+87
-4
lines changed

2 files changed

+87
-4
lines changed

databento/common/dbnstore.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
from databento.common.enums import Compression, Schema, SType
3333
from databento.common.metadata import MetadataDecoder
3434
from databento.common.symbology import ProductIdMappingInterval
35+
from databento.historical.error import BentoError
3536

3637

3738
logger = logging.getLogger(__name__)
@@ -345,11 +346,18 @@ def __init__(self, data_source: DataSource) -> None:
345346

346347
def __iter__(self) -> Generator[np.void, None, None]:
347348
reader = self.reader
349+
dtype = STRUCT_MAP[self.schema]
348350
while True:
349351
raw = reader.read(self.record_size)
350352
if raw:
351-
rec = np.frombuffer(raw, dtype=STRUCT_MAP[self.schema])
352-
yield rec[0]
353+
try:
354+
rec = np.frombuffer(raw, dtype)
355+
except ValueError as value_error:
356+
raise BentoError(
357+
f"Error decoding {len(raw)} bytes for {self.schema} iteration",
358+
) from value_error
359+
else:
360+
yield rec[0]
353361
else:
354362
break
355363

@@ -971,4 +979,11 @@ def to_ndarray(self) -> np.ndarray[Any, Any]:
971979
972980
"""
973981
data: bytes = self.reader.read()
974-
return np.frombuffer(data, dtype=self.dtype)
982+
try:
983+
nd_array = np.frombuffer(data, dtype=self.dtype)
984+
except ValueError as value_error:
985+
raise BentoError(
986+
f"Error decoding {len(data)} bytes to {self.schema} `ndarray`",
987+
) from value_error
988+
else:
989+
return nd_array

tests/test_historical_bento.py

Lines changed: 69 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import zstandard
1212
from databento.common.dbnstore import DBNStore
1313
from databento.common.enums import Schema, SType
14+
from databento.historical.error import BentoError
1415

1516

1617
def test_from_file_when_not_exists_raises_expected_exception() -> None:
@@ -820,11 +821,78 @@ def test_dbnstore_compression_equality(
820821
with zstandard by default.
821822
"""
822823
zstd_stub_data = test_data(schema)
823-
dbn_stub_data = zstandard.ZstdDecompressor().stream_reader(test_data(schema)).read()
824+
dbn_stub_data = zstandard.ZstdDecompressor().stream_reader(zstd_stub_data).read()
824825

825826
zstd_dbnstore = DBNStore.from_bytes(zstd_stub_data)
826827
dbn_dbnstore = DBNStore.from_bytes(dbn_stub_data)
827828

828829
assert len(zstd_dbnstore.to_ndarray()) == len(dbn_dbnstore.to_ndarray())
829830
assert zstd_dbnstore.metadata == dbn_dbnstore.metadata
830831
assert zstd_dbnstore.reader.read() == dbn_dbnstore.reader.read()
832+
833+
834+
def test_dbnstore_buffer_short(
835+
test_data: Callable[[Schema], bytes],
836+
tmp_path: Path,
837+
) -> None:
838+
"""
839+
Test that creating a DBNStore with missing bytes raises a
840+
BentoError when decoding.
841+
"""
842+
# Arrange
843+
dbn_stub_data = (
844+
zstandard.ZstdDecompressor().stream_reader(test_data(Schema.MBO)).read()
845+
)
846+
847+
# Act
848+
dbnstore = DBNStore.from_bytes(data=dbn_stub_data[:-2])
849+
850+
# Assert
851+
with pytest.raises(BentoError):
852+
list(dbnstore)
853+
854+
with pytest.raises(BentoError):
855+
dbnstore.to_ndarray()
856+
857+
with pytest.raises(BentoError):
858+
dbnstore.to_df()
859+
860+
with pytest.raises(BentoError):
861+
dbnstore.to_csv(tmp_path / "test.csv")
862+
863+
with pytest.raises(BentoError):
864+
dbnstore.to_json(tmp_path / "test.json")
865+
866+
867+
def test_dbnstore_buffer_long(
868+
test_data: Callable[[Schema], bytes],
869+
tmp_path: Path,
870+
) -> None:
871+
"""
872+
Test that creating a DBNStore with excess bytes raises a
873+
BentoError when decoding.
874+
"""
875+
# Arrange
876+
dbn_stub_data = (
877+
zstandard.ZstdDecompressor().stream_reader(test_data(Schema.MBO)).read()
878+
)
879+
880+
# Act
881+
dbn_stub_data += b"\xF0\xFF"
882+
dbnstore = DBNStore.from_bytes(data=dbn_stub_data)
883+
884+
# Assert
885+
with pytest.raises(BentoError):
886+
list(dbnstore)
887+
888+
with pytest.raises(BentoError):
889+
dbnstore.to_ndarray()
890+
891+
with pytest.raises(BentoError):
892+
dbnstore.to_df()
893+
894+
with pytest.raises(BentoError):
895+
dbnstore.to_csv(tmp_path / "test.csv")
896+
897+
with pytest.raises(BentoError):
898+
dbnstore.to_json(tmp_path / "test.json")

0 commit comments

Comments
 (0)