Skip to content

Commit e832205

Browse files
committed
support writing pandas DataFrame and Series
1 parent dc8a6e8 commit e832205

File tree

5 files changed

+132
-58
lines changed

5 files changed

+132
-58
lines changed

pymarketstore/grpc_client.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import grpc
22
import logging
33
import numpy as np
4+
import pandas as pd
45

56
from typing import List, Union
67

78
from .params import Params, ListSymbolsFormat
89
from .proto import marketstore_pb2 as proto
910
from .proto import marketstore_pb2_grpc as gp
1011
from .results import QueryReply
11-
from .utils import is_iterable
12+
from .utils import is_iterable, timeseries_data_to_write_request
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -27,37 +28,18 @@ def query(self, params: Union[Params, List[Params]]) -> QueryReply:
2728
reply = self.stub.Query(self._build_query(params))
2829
return QueryReply.from_grpc_response(reply)
2930

30-
def write(self, recarray: np.array, tbk: str, isvariablelength: bool = False) -> proto.MultiServerResponse:
31-
types = [
32-
recarray.dtype[name].str.replace('<', '')
33-
for name in recarray.dtype.names
34-
]
35-
names = recarray.dtype.names
36-
data = [
37-
bytes(memoryview(recarray[name]))
38-
for name in recarray.dtype.names
39-
]
40-
length = len(recarray)
41-
start_index = {tbk: 0}
42-
lengths = {tbk: len(recarray)}
43-
44-
req = proto.MultiWriteRequest(requests=[
45-
proto.WriteRequest(
46-
data=proto.NumpyMultiDataset(
47-
data=proto.NumpyDataset(
48-
column_types=types,
49-
column_names=names,
50-
column_data=data,
51-
length=length,
52-
# data_shapes = [],
53-
),
54-
start_index=start_index,
55-
lengths=lengths,
56-
),
57-
is_variable_length=isvariablelength,
58-
)
59-
])
60-
31+
def write(self, data: Union[pd.DataFrame, pd.Series, np.ndarray, np.recarray],
32+
tbk: str,
33+
isvariablelength: bool = False,
34+
) -> proto.MultiServerResponse:
35+
req = proto.MultiWriteRequest(requests=[dict(
36+
data=dict(
37+
data=timeseries_data_to_write_request(data, tbk),
38+
start_index={tbk: 0},
39+
lengths={tbk: len(data)},
40+
),
41+
is_variable_length=isvariablelength,
42+
)])
6143
return self.stub.Write(req)
6244

6345
def _build_query(self, params: Union[Params, List[Params]]) -> proto.MultiQueryRequest:

pymarketstore/jsonrpc_client.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import numpy as np
3+
import pandas as pd
34
import re
45
import requests
56

@@ -9,7 +10,7 @@
910
from .params import Params, ListSymbolsFormat
1011
from .results import QueryReply
1112
from .stream import StreamConn
12-
from .utils import is_iterable
13+
from .utils import is_iterable, timeseries_data_to_write_request
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -36,31 +37,21 @@ def query(self, params: Union[Params, List[Params]]) -> QueryReply:
3637
])
3738
return QueryReply.from_response(reply)
3839

39-
def write(self, recarray: np.array, tbk: str, isvariablelength: bool = False) -> str:
40-
data = {}
41-
data['types'] = [
42-
recarray.dtype[name].str.replace('<', '')
43-
for name in recarray.dtype.names
44-
]
45-
data['names'] = recarray.dtype.names
46-
data['data'] = [
47-
bytes(memoryview(recarray[name]))
48-
for name in recarray.dtype.names
49-
]
50-
data['length'] = len(recarray)
51-
data['startindex'] = {tbk: 0}
52-
data['lengths'] = {tbk: len(recarray)}
53-
write_request = {}
54-
write_request['dataset'] = data
55-
write_request['is_variable_length'] = isvariablelength
56-
writer = {}
57-
writer['requests'] = [write_request]
58-
59-
try:
60-
return self.rpc.call("DataService.Write", **writer)
61-
except requests.exceptions.ConnectionError:
62-
raise requests.exceptions.ConnectionError(
63-
"Could not contact server")
40+
def write(self, data: Union[pd.DataFrame, pd.Series, np.ndarray, np.recarray],
41+
tbk: str,
42+
isvariablelength: bool = False,
43+
) -> dict:
44+
dataset = timeseries_data_to_write_request(data, tbk)
45+
return self.rpc.call("DataService.Write", requests=[dict(
46+
dataset=dict(
47+
types=dataset['column_types'],
48+
names=dataset['column_names'],
49+
data=dataset['column_data'],
50+
startindex={tbk: 0},
51+
lengths={tbk: len(data)},
52+
),
53+
is_variable_length=isvariablelength,
54+
)])
6455

6556
def list_symbols(self, fmt: ListSymbolsFormat = ListSymbolsFormat.SYMBOL) -> List[str]:
6657
reply = self._request('DataService.ListSymbols', format=fmt.value)

pymarketstore/utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,55 @@ def is_iterable(something: Any) -> bool:
1919
:return: bool. true if something is a list, tuple or set
2020
"""
2121
return isinstance(something, (list, tuple, set))
22+
23+
24+
def timeseries_data_to_write_request(data: Union[pd.DataFrame, pd.Series, np.ndarray, np.recarray],
25+
tbk: str,
26+
) -> dict:
27+
if isinstance(data, (np.ndarray, np.recarray)):
28+
return _np_array_to_dataset_params(data)
29+
elif isinstance(data, pd.Series):
30+
return _pd_series_to_dataset_params(data, tbk)
31+
elif isinstance(data, pd.DataFrame):
32+
return _pd_dataframe_to_dataset_params(data)
33+
raise TypeError('data must be pd.DataFrame, pd.Series, np.ndarray, or np.recarray')
34+
35+
36+
def _np_array_to_dataset_params(data: Union[np.ndarray, np.recarray]) -> dict:
37+
if not data.dtype.names:
38+
raise TypeError('numpy arrays must declare named column dtypes')
39+
40+
return dict(column_types=[data.dtype[name].str.replace('<', '')
41+
for name in data.dtype.names],
42+
column_names=list(data.dtype.names),
43+
column_data=[bytes(memoryview(data[name]))
44+
for name in data.dtype.names],
45+
length=len(data))
46+
47+
48+
def _pd_series_to_dataset_params(data: pd.Series, tbk: str) -> dict:
49+
# single column of data (indexed by timestamp, eg from ohlcv_df['ColName'])
50+
if data.index.name == 'Epoch':
51+
epoch = bytes(memoryview(data.index.to_numpy(dtype='i8') // 10**9))
52+
return dict(column_types=['i8', data.dtype.str.replace('<', '')],
53+
column_names=['Epoch', data.name or tbk.split('/')[-1]],
54+
column_data=[epoch, bytes(memoryview(data.to_numpy()))],
55+
length=len(data))
56+
57+
# single row of data (named indexes for one timestamp, eg from ohlcv_df.iloc[N])
58+
epoch = bytes(memoryview(data.name.to_numpy().astype(dtype='i8') // 10**9))
59+
return dict(column_types=['i8'] + [data.dtype.str.replace('<', '')
60+
for _ in range(0, len(data))],
61+
column_names=['Epoch'] + data.index.to_list(),
62+
column_data=[epoch] + [bytes(memoryview(val)) for val in data.array],
63+
length=1)
64+
65+
66+
def _pd_dataframe_to_dataset_params(data: pd.DataFrame) -> dict:
67+
epoch = bytes(memoryview(data.index.to_numpy(dtype='i8') // 10**9))
68+
return dict(column_types=['i8'] + [dtype.str.replace('<', '')
69+
for dtype in data.dtypes],
70+
column_names=['Epoch'] + data.columns.to_list(),
71+
column_data=[epoch] + [bytes(memoryview(data[col].to_numpy()))
72+
for col in data.columns],
73+
length=len(data))

tests/test_results.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pandas as pd
2+
13
from ast import literal_eval
24
from pymarketstore import results
35

@@ -35,6 +37,11 @@
3537
'version': 'dev'}
3638
""") # noqa: E501
3739

40+
btc_array = results.decode_responses(testdata1['responses'])[0]['BTC/1Min/OHLCV']
41+
btc_bytes = testdata1['responses'][0]['result']['data']
42+
btc_df = pd.DataFrame(btc_array).set_index('Epoch')
43+
btc_df.index = pd.DatetimeIndex(btc_df.index * 10**9, tz='UTC')
44+
3845

3946
def test_results():
4047
reply = results.QueryReply.from_response(testdata1)

tests/test_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import pandas as pd
2+
3+
from pymarketstore.utils import timeseries_data_to_write_request
4+
5+
from .test_results import btc_array, btc_bytes, btc_df
6+
7+
8+
class TestTimeseriesDataToWriteRequest:
9+
def test_np_array(self):
10+
assert timeseries_data_to_write_request(btc_array, 'BTC/1Min/OHLCV') == dict(
11+
column_data=btc_bytes,
12+
column_names=['Epoch', 'Open', 'High', 'Low', 'Close', 'Volume'],
13+
column_types=['i8', 'f8', 'f8', 'f8', 'f8', 'f8'],
14+
length=5,
15+
)
16+
17+
def test_pd_series_indexed_by_timestamp(self):
18+
series = pd.Series(btc_df.Open, index=btc_df.index)
19+
assert timeseries_data_to_write_request(series, 'BTC/1Min/Open') == dict(
20+
column_data=[btc_bytes[0], btc_bytes[1]],
21+
column_names=['Epoch', 'Open'],
22+
column_types=['i8', 'f8'],
23+
length=5,
24+
)
25+
26+
def test_pd_series_row_from_df(self):
27+
series = btc_df.iloc[0]
28+
expected_epoch = bytes(memoryview(series.name.to_numpy().astype(dtype='i8') // 10**9))
29+
assert timeseries_data_to_write_request(series, 'BTC/1Min/OHLCV') == dict(
30+
column_data=[expected_epoch] + [bytes(memoryview(val)) for val in series.array],
31+
column_names=['Epoch', 'Open', 'High', 'Low', 'Close', 'Volume'],
32+
column_types=['i8', 'f8', 'f8', 'f8', 'f8', 'f8'],
33+
length=1,
34+
)
35+
36+
def test_pd_dataframe(self):
37+
assert timeseries_data_to_write_request(btc_df, 'BTC/1Min/OHLCV') == dict(
38+
column_data=btc_bytes,
39+
column_names=['Epoch', 'Open', 'High', 'Low', 'Close', 'Volume'],
40+
column_types=['i8', 'f8', 'f8', 'f8', 'f8', 'f8'],
41+
length=5,
42+
)

0 commit comments

Comments
 (0)