Skip to content

Commit 7e992ca

Browse files
committed
use aiohttp.websockets instead of websockets
1 parent 7c04165 commit 7e992ca

File tree

6 files changed

+28
-28
lines changed

6 files changed

+28
-28
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ install:
2222

2323
script:
2424
- py.test --cov=gdax --cov-report=term --cov-append tests
25-
25+
2626
after_success:
2727
- codecov
2828

TODO.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# TODO
22

33
- improve test coverage, especially for the order book
4-
- use aiohttp.websockets instead of websockets
54
- client-side rate limiting
65
- convert return API values from string to Decimal
76
- better enforce API rules

gdax/orderbook.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
from bintrees import FastRBTree
1515
import aiofiles
16-
import websockets
16+
import aiohttp
17+
# import websockets
1718

1819
import gdax.trader
1920
import gdax.utils
@@ -43,11 +44,14 @@ def __init__(self, product_ids='ETH-USD', api_key=None, api_secret=None,
4344
self._asks = {product_id: FastRBTree() for product_id in product_ids}
4445
self._bids = {product_id: FastRBTree() for product_id in product_ids}
4546
self._sequences = {product_id: None for product_id in product_ids}
46-
self._ws = None
47+
self._ws_session = None
4748
self._ws_connect = None
49+
self._ws = None
4850

4951
async def _init(self):
50-
self._ws_connect = websockets.connect('wss://ws-feed.gdax.com')
52+
self._ws_session = aiohttp.ClientSession()
53+
self._ws_connect = self._ws_session.ws_connect(
54+
'wss://ws-feed.gdax.com')
5155
self._ws = await self._ws_connect.__aenter__()
5256

5357
# subscribe
@@ -89,7 +93,7 @@ async def __aenter__(self):
8993
return self
9094

9195
async def __aexit__(self, exc_type, exc, traceback):
92-
return await self._ws_connect.__aexit__(exc_type, exc, traceback)
96+
return await self._ws_session.__aexit__(exc_type, exc, traceback)
9397

9498
async def _open_log_file(self):
9599
if self.trade_log_file_path is not None:
@@ -101,10 +105,10 @@ async def _close_log_file(self):
101105
await self._trade_file.__aexit__(None, None, None)
102106

103107
async def _send(self, **kwargs):
104-
await self._ws.send(json.dumps(kwargs))
108+
await self._ws.send_json(kwargs)
105109

106110
async def _recv(self):
107-
json_data = await self._ws.recv()
111+
json_data = await self._ws.receive_str()
108112
if self._trade_file:
109113
await self._trade_file.write(f'W {json_data}\n')
110114
return json.loads(json_data)
@@ -132,9 +136,11 @@ async def _subscribe(self):
132136
async def handle_message(self):
133137
try:
134138
message = await self._recv()
135-
except websockets.exceptions.ConnectionClosed:
136-
await self._ws_connect.__aexit__(None, None, None)
137-
self._init()
139+
except aiohttp.ServerDisconnectedError as exc:
140+
logging.error(
141+
f'Error: Exception: f{exc}. Re-initializing websocket.')
142+
await self._ws_session.__aexit__(None, None, None)
143+
await self._init()
138144
return
139145

140146
product_id = message['product_id']
@@ -146,10 +152,10 @@ async def handle_message(self):
146152
# from getProductOrderBook)
147153
return message
148154
elif sequence > self._sequences[product_id] + 1:
149-
logging.info(
155+
logging.error(
150156
'Error: messages missing ({} - {}). Re-initializing websocket.'
151157
.format(sequence, self._sequences[product_id]))
152-
await self._ws_connect.__aexit__(None, None, None)
158+
await self._ws_session.__aexit__(None, None, None)
153159
await self._init()
154160
return
155161

@@ -243,7 +249,6 @@ def match(self, product_id, order):
243249
self.set_asks(product_id, price, asks)
244250

245251
def change(self, product_id, order):
246-
logging.info((product_id, order))
247252
if 'new_size' not in order:
248253
# market order
249254
# TODO

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1+
aiofiles==0.3.1
12
aiohttp==2.2.0
23
async-timeout==1.2.1
34
requests==2.18.1
4-
websockets==3.3
5-
aiofiles==0.3.1
65
bintrees==2.0.7

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
author_email='[email protected]',
1313
install_requires=[
1414
'aiohttp==2.2.0',
15+
'aiofiles==0.3.1',
1516
'async-timeout==1.2.1',
1617
'requests==2.18.1',
17-
'websockets==3.3',
1818
'bintrees==2.0.7',
1919
],
2020
packages=find_packages(),

tests/test_orderbook.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,20 +138,19 @@ def generate_id():
138138

139139

140140
@pytest.mark.asyncio
141-
@patch('websockets.connect', new_callable=AsyncContextManagerMock)
141+
@patch('aiohttp.ClientSession.ws_connect',
142+
new_callable=AsyncContextManagerMock)
142143
class TestOrderbook(object):
143144
@patch('gdax.trader.Trader.get_product_order_book')
144145
async def test_basic_init(self, mock_book, mock_connect):
145-
mock_connect.return_value.aenter.send = CoroutineMock()
146+
mock_connect.return_value.aenter.send_json = CoroutineMock()
146147
mock_book.return_value = test_book
147148

148149
product_id = 'ETH-USD'
149150
product_ids = [product_id]
150151
async with gdax.orderbook.OrderBook(product_ids) as orderbook:
151152
msg = {'type': 'subscribe', 'product_ids': product_ids}
152-
subscribe_msg = json.dumps(msg)
153-
mock_connect.return_value.aenter.send.assert_called_with(
154-
subscribe_msg)
153+
mock_connect.return_value.aenter.send_json.assert_called_with(msg)
155154

156155
mock_book.assert_called_with(level=3)
157156

@@ -195,15 +194,13 @@ async def test_basic_init(self, mock_book, mock_connect):
195194

196195
@patch('gdax.trader.Trader.get_product_order_book')
197196
async def test_heartbeat(self, mock_book, mock_connect):
198-
mock_connect.return_value.aenter.send = CoroutineMock()
197+
mock_connect.return_value.aenter.send_json = CoroutineMock()
199198

200199
mock_book.return_value = {'bids': [], 'asks': [], 'sequence': 1}
201200
product_ids = ['ETH-USD']
202201
async with gdax.orderbook.OrderBook(product_ids,
203202
use_heartbeat=True) as orderbook:
204-
msg1 = {'type': 'subscribe', 'product_ids': product_ids}
205-
msg2 = {'type': 'heartbeat', 'on': True}
206-
subscribe_msg = json.dumps(msg1)
207-
heartbeat_msg = json.dumps(msg2)
203+
subscribe_msg = {'type': 'subscribe', 'product_ids': product_ids}
204+
heartbeat_msg = {'type': 'heartbeat', 'on': True}
208205
calls = [call(subscribe_msg), call(heartbeat_msg)]
209-
mock_connect.return_value.aenter.send.assert_has_calls(calls)
206+
mock_connect.return_value.aenter.send_json.assert_has_calls(calls)

0 commit comments

Comments
 (0)