Skip to content

Commit 38677e7

Browse files
committed
FIX: Fix non-int numerical types in symbols list
1 parent 3d6777d commit 38677e7

File tree

6 files changed

+102
-88
lines changed

6 files changed

+102
-88
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Added `metadata` property to `Live`
88
- Added `DatatbentoLiveProtocol` class
99
- Added support for emitting warnings in API response headers
10+
- Fixed issue with `numpy` types not being handled in symbols field
1011
- Upgraded `aiohttp` to 3.8.3
1112
- Upgraded `numpy` to to 1.23.5
1213
- Upgraded `pandas` to to 1.5.3

databento/common/parsing.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
from collections.abc import Iterable as IterableABC
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterable
24
from datetime import date
3-
from functools import partial, singledispatch
4-
from typing import Iterable, Optional, Union
5+
from functools import partial
6+
from functools import singledispatch
7+
from numbers import Number
8+
from typing import Optional, Union
59

610
import pandas as pd
11+
712
from databento.common.enums import SType
813
from databento.common.symbology import ALL_SYMBOLS
914
from databento.common.validation import validate_smart_symbol
@@ -55,15 +60,15 @@ def optional_values_list_to_string(
5560

5661
@singledispatch
5762
def optional_symbols_list_to_string(
58-
symbols: Optional[Union[Iterable[str], Iterable[int], str, int]],
59-
_: SType,
63+
symbols: Optional[Union[Iterable[str], Iterable[Number], str, Number]],
64+
stype_in: SType,
6065
) -> str:
6166
"""
6267
Concatenate a symbols string or iterable of symbol strings (if not None).
6368
6469
Parameters
6570
----------
66-
symbols : iterable of str, iterable of int, str, or int optional
71+
symbols : iterable of str, iterable of Number, str, or Number optional
6772
The symbols to concatenate.
6873
stype_in : SType
6974
The input symbology type for the request.
@@ -98,18 +103,18 @@ def _(_: None, __: SType) -> str:
98103

99104

100105
@optional_symbols_list_to_string.register
101-
def _(symbols: int, stype_in: SType) -> str:
106+
def _(symbols: Number, stype_in: SType) -> str:
102107
"""
103108
Dispatch method for optional_symbols_list_to_string.
104-
Handles int, alerting when an integer is given for
105-
STypes that expect strings.
109+
Handles numerical types, alerting when an integer is
110+
given for STypes that expect strings.
106111
107112
See Also
108113
--------
109114
optional_symbols_list_to_string
110115
111116
"""
112-
if stype_in == SType.INSTRUMENT_ID or stype_in == "product_id":
117+
if stype_in == SType.INSTRUMENT_ID:
113118
return str(symbols)
114119
raise ValueError(
115120
f"value `{symbols}` is not a valid symbol for stype {stype_in}; "
@@ -147,7 +152,7 @@ def _(symbols: str, stype_in: SType) -> str:
147152
return symbols.strip().upper()
148153

149154

150-
@optional_symbols_list_to_string.register(cls=IterableABC)
155+
@optional_symbols_list_to_string.register(cls=Iterable)
151156
def _(symbols: Union[Iterable[str], Iterable[int]], stype_in: SType) -> str:
152157
"""
153158
Dispatch method for optional_symbols_list_to_string.

databento/live/client.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,27 @@
44
import queue
55
import threading
66
from concurrent import futures
7+
from numbers import Number
78
from typing import IO, Callable, Iterable, List, Optional, Union
89

910
import databento_dbn
11+
1012
from databento.common.cram import BUCKET_ID_LENGTH
11-
from databento.common.enums import Dataset, Schema, SType
13+
from databento.common.enums import Dataset
14+
from databento.common.enums import Schema
15+
from databento.common.enums import SType
1216
from databento.common.error import BentoError
13-
from databento.common.parsing import (
14-
optional_datetime_to_unix_nanoseconds,
15-
optional_symbols_list_to_string,
16-
)
17+
from databento.common.parsing import optional_datetime_to_unix_nanoseconds
18+
from databento.common.parsing import optional_symbols_list_to_string
1719
from databento.common.symbology import ALL_SYMBOLS
18-
from databento.common.validation import validate_enum, validate_semantic_string
20+
from databento.common.validation import validate_enum
21+
from databento.common.validation import validate_semantic_string
1922
from databento.live import DBNRecord
20-
from databento.live.session import (
21-
DEFAULT_REMOTE_PORT,
22-
DBNQueue,
23-
Session,
24-
SessionMetadata,
25-
_SessionProtocol,
26-
)
23+
from databento.live.session import DEFAULT_REMOTE_PORT
24+
from databento.live.session import DBNQueue
25+
from databento.live.session import Session
26+
from databento.live.session import SessionMetadata
27+
from databento.live.session import _SessionProtocol
2728

2829

2930
logger = logging.getLogger(__name__)
@@ -358,7 +359,7 @@ def subscribe(
358359
self,
359360
dataset: Union[Dataset, str],
360361
schema: Union[Schema, str],
361-
symbols: Union[Iterable[str], Iterable[int], str, int] = ALL_SYMBOLS,
362+
symbols: Union[Iterable[str], Iterable[Number], str, Number] = ALL_SYMBOLS,
362363
stype_in: Union[SType, str] = SType.RAW_SYMBOL,
363364
start: Optional[Union[str, int]] = None,
364365
) -> None:
@@ -378,7 +379,7 @@ def subscribe(
378379
The dataset for the subscription.
379380
schema : Schema or str
380381
The schema to subscribe to.
381-
symbols : Iterable[Union[str, int]] or str, default 'ALL_SYMBOLS'
382+
symbols : Iterable[Union[str, Number]] or str or Number, default 'ALL_SYMBOLS'
382383
The symbols to subscribe to.
383384
stype_in : SType or str, default 'raw_symbol'
384385
The input symbology type to resolve from.

databento/live/protocol.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
import asyncio
22
import logging
3-
from functools import singledispatch, update_wrapper
3+
from functools import singledispatch
4+
from functools import update_wrapper
5+
from numbers import Number
46
from typing import Any, Callable, Iterable, Optional, Union
57

68
import databento_dbn
79

810
from databento.common import cram
9-
from databento.common.enums import Dataset, Schema, SType
11+
from databento.common.enums import Dataset
12+
from databento.common.enums import Schema
13+
from databento.common.enums import SType
1014
from databento.common.error import BentoError
11-
from databento.common.parsing import (
12-
optional_datetime_to_unix_nanoseconds,
13-
optional_symbols_list_to_string,
14-
)
15+
from databento.common.parsing import optional_datetime_to_unix_nanoseconds
16+
from databento.common.parsing import optional_symbols_list_to_string
1517
from databento.common.symbology import ALL_SYMBOLS
16-
from databento.common.validation import validate_enum, validate_semantic_string
17-
from databento.live.gateway import (
18-
AuthenticationRequest,
19-
AuthenticationResponse,
20-
ChallengeRequest,
21-
GatewayControl,
22-
GatewayDecoder,
23-
Greeting,
24-
SessionStart,
25-
SubscriptionRequest,
26-
)
18+
from databento.common.validation import validate_enum
19+
from databento.common.validation import validate_semantic_string
20+
from databento.live.gateway import AuthenticationRequest
21+
from databento.live.gateway import AuthenticationResponse
22+
from databento.live.gateway import ChallengeRequest
23+
from databento.live.gateway import GatewayControl
24+
from databento.live.gateway import GatewayDecoder
25+
from databento.live.gateway import Greeting
26+
from databento.live.gateway import SessionStart
27+
from databento.live.gateway import SubscriptionRequest
28+
2729

2830
DBNRecord = Union[
2931
databento_dbn.MBOMsg,
@@ -282,7 +284,7 @@ def received_record(self, record: DBNRecord) -> None:
282284
def subscribe(
283285
self,
284286
schema: Union[Schema, str],
285-
symbols: Union[Iterable[str], Iterable[int], str, int] = ALL_SYMBOLS,
287+
symbols: Union[Iterable[str], Iterable[Number], str, Number] = ALL_SYMBOLS,
286288
stype_in: Union[SType, str] = SType.RAW_SYMBOL,
287289
start: Optional[Union[str, int]] = None,
288290
) -> None:
@@ -293,7 +295,7 @@ def subscribe(
293295
----------
294296
schema : Schema or str
295297
The schema to subscribe to.
296-
symbols : Iterable[Union[str, int]] or str, default 'ALL_SYMBOLS'
298+
symbols : Iterable[Union[str, Number]] or str or Number, default 'ALL_SYMBOLS'
297299
The symbols to subscribe to.
298300
stype_in : SType or str, default 'raw_symbol'
299301
The input symbology type to resolve from.

databento/live/session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import queue
55
import struct
66
import threading
7+
from numbers import Number
78
from typing import IO, Callable, Iterable, List, Optional, Set, Union
89

910
import databento_dbn
@@ -353,7 +354,7 @@ def subscribe(
353354
self,
354355
dataset: Union[Dataset, str],
355356
schema: Union[Schema, str],
356-
symbols: Union[Iterable[str], Iterable[int], str, int] = ALL_SYMBOLS,
357+
symbols: Union[Iterable[str], Iterable[Number], str, Number] = ALL_SYMBOLS,
357358
stype_in: Union[SType, str] = SType.RAW_SYMBOL,
358359
start: Optional[Union[str, int]] = None,
359360
) -> None:
@@ -368,7 +369,7 @@ def subscribe(
368369
The dataset for the subscription.
369370
schema : Schema or str
370371
The schema to subscribe to.
371-
symbols : Iterable[Union[str, int]] or str, default 'ALL_SYMBOLS'
372+
symbols : Iterable[Union[str, Number]] or str or Number, default 'ALL_SYMBOLS'
372373
The symbols to subscribe to.
373374
stype_in : SType or str, default 'raw_symbol'
374375
The input symbology type to resolve from.

tests/test_common_parsing.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
import datetime as dt
2+
from numbers import Number
23
from typing import Any, List, Optional, Type, Union
34

45
import numpy as np
56
import pandas as pd
67
import pytest
7-
88
from databento.common.enums import SType
9-
from databento.common.parsing import (
10-
optional_date_to_string,
11-
optional_datetime_to_string,
12-
optional_datetime_to_unix_nanoseconds,
13-
optional_symbols_list_to_string,
14-
optional_values_list_to_string,
15-
)
9+
from databento.common.parsing import optional_date_to_string
10+
from databento.common.parsing import optional_datetime_to_string
11+
from databento.common.parsing import optional_datetime_to_unix_nanoseconds
12+
from databento.common.parsing import optional_symbols_list_to_string
13+
from databento.common.parsing import optional_values_list_to_string
14+
1615

1716
# Set the type to `Any` to disable mypy type checking. Used to test if functions
1817
# will raise a `TypeError` when passed an incorrectly-typed argument.
@@ -47,37 +46,6 @@ def test_maybe_values_list_to_string_given_valid_inputs_returns_expected(
4746
assert result == expected
4847

4948

50-
@pytest.mark.parametrize(
51-
"symbols, stype, expected",
52-
[
53-
pytest.param("NVDA", SType.RAW_SYMBOL, "NVDA"),
54-
pytest.param(" nvda ", SType.RAW_SYMBOL, "NVDA"),
55-
pytest.param("NVDA,amd", SType.RAW_SYMBOL, "NVDA,AMD"),
56-
pytest.param("NVDA,amd,NOC,", SType.RAW_SYMBOL, "NVDA,AMD,NOC"),
57-
pytest.param("NVDA, amd,NOC, ", SType.RAW_SYMBOL, "NVDA,AMD,NOC"),
58-
pytest.param(["NVDA", ["NOC", "AMD"]], SType.RAW_SYMBOL, "NVDA,NOC,AMD"),
59-
pytest.param(["NVDA", "NOC,AMD"], SType.RAW_SYMBOL, "NVDA,NOC,AMD"),
60-
pytest.param("", SType.RAW_SYMBOL, ValueError),
61-
pytest.param([""], SType.RAW_SYMBOL, ValueError),
62-
pytest.param(["NVDA", ""], SType.RAW_SYMBOL, ValueError),
63-
pytest.param(["NVDA", [""]], SType.RAW_SYMBOL, ValueError),
64-
],
65-
)
66-
def test_optional_symbols_list_to_string_native(
67-
symbols: Optional[Union[List[int], int]],
68-
stype: SType,
69-
expected: Union[str, Type[Exception]],
70-
) -> None:
71-
"""
72-
Test that str are allowed for SType.RAW_SYMBOL.
73-
"""
74-
if isinstance(expected, str):
75-
assert optional_symbols_list_to_string(symbols, stype) == expected
76-
else:
77-
with pytest.raises(expected):
78-
optional_symbols_list_to_string(symbols, stype)
79-
80-
8149
def test_maybe_symbols_list_to_string_given_invalid_input_raises_type_error() -> None:
8250
# Arrange, Act, Assert
8351
with pytest.raises(TypeError):
@@ -134,7 +102,7 @@ def test_optional_symbols_list_to_string_given_valid_inputs_returns_expected(
134102
],
135103
)
136104
def test_optional_symbols_list_to_string_int(
137-
symbols: Optional[Union[List[int], int]],
105+
symbols: Optional[Union[List[Number], Number]],
138106
stype: SType,
139107
expected: Union[str, Type[Exception]],
140108
) -> None:
@@ -150,6 +118,42 @@ def test_optional_symbols_list_to_string_int(
150118
optional_symbols_list_to_string(symbols, stype)
151119

152120

121+
@pytest.mark.parametrize(
122+
"symbols, stype, expected",
123+
[
124+
pytest.param(np.byte(120), SType.INSTRUMENT_ID, "120"),
125+
pytest.param(np.short(32_000), SType.INSTRUMENT_ID, "32000"),
126+
pytest.param(
127+
[np.intc(12345), np.intc(67890)], SType.INSTRUMENT_ID, "12345,67890",
128+
),
129+
pytest.param(
130+
[np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, "12345,67890",
131+
),
132+
pytest.param(
133+
[np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, "12345,67890",
134+
),
135+
pytest.param(
136+
[np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, "12345,67890",
137+
),
138+
],
139+
)
140+
def test_optional_symbols_list_to_string_numpy(
141+
symbols: Optional[Union[List[Number], Number]],
142+
stype: SType,
143+
expected: Union[str, Type[Exception]],
144+
) -> None:
145+
"""
146+
Test that weird numpy types are allowed for SType.INSTRUMENT_ID.
147+
If integers are given for a different SType we expect
148+
a ValueError.
149+
"""
150+
if isinstance(expected, str):
151+
assert optional_symbols_list_to_string(symbols, stype) == expected
152+
else:
153+
with pytest.raises(expected):
154+
optional_symbols_list_to_string(symbols, stype)
155+
156+
153157
@pytest.mark.parametrize(
154158
"symbols, stype, expected",
155159
[
@@ -167,7 +171,7 @@ def test_optional_symbols_list_to_string_int(
167171
],
168172
)
169173
def test_optional_symbols_list_to_string_raw_symbol(
170-
symbols: Optional[Union[List[int], int]],
174+
symbols: Optional[Union[List[Number], Number]],
171175
stype: SType,
172176
expected: Union[str, Type[Exception]],
173177
) -> None:

0 commit comments

Comments
 (0)