Skip to content

Commit 5e0dc09

Browse files
Wrap WS event callbacks with lock (#134)
1 parent 1ca5ab0 commit 5e0dc09

File tree

1 file changed

+39
-26
lines changed

1 file changed

+39
-26
lines changed

src/dipdup/datasources/tzkt/datasource.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from datetime import datetime, timezone
55
from decimal import Decimal
66
from enum import Enum
7-
from typing import Any, AsyncGenerator, DefaultDict, Dict, List, NoReturn, Optional, Set, Tuple, cast
7+
from functools import wraps
8+
from typing import Any, AsyncGenerator, Callable, DefaultDict, Dict, List, NoReturn, Optional, Set, Tuple, cast
89

910
from aiohttp import ClientResponseError
1011
from aiosignalrcore.hub.base_hub_connection import BaseHubConnection # type: ignore
@@ -340,7 +341,7 @@ def __init__(
340341
self._transaction_subscriptions: Set[str] = set()
341342
self._origination_subscriptions: bool = False
342343
self._big_map_subscriptions: Dict[str, Set[str]] = {}
343-
self._client: Optional[BaseHubConnection] = None
344+
self._ws_client: Optional[BaseHubConnection] = None
344345

345346
self._block_cache: BlockCache = BlockCache()
346347
self._head: Optional[Head] = None
@@ -587,41 +588,53 @@ async def add_index(self, index_config: ResolvedIndexConfigT) -> None:
587588

588589
await self._on_connect()
589590

590-
def _get_client(self) -> BaseHubConnection:
591+
def _get_ws_client(self) -> BaseHubConnection:
591592
"""Create SignalR client, register message callbacks"""
592-
if self._client is None:
593-
self._logger.info('Creating websocket client')
594-
self._client = (
595-
HubConnectionBuilder()
596-
.with_url(self._http._url + '/v1/events')
597-
.with_automatic_reconnect(
598-
{
599-
"type": "raw",
600-
"keep_alive_interval": 10,
601-
"reconnect_interval": 5,
602-
"max_attempts": 5,
603-
}
604-
)
605-
).build()
593+
if self._ws_client:
594+
return self._ws_client
595+
596+
self._logger.info('Creating websocket client')
597+
self._ws_client = (
598+
HubConnectionBuilder()
599+
.with_url(self._http._url + '/v1/events')
600+
.with_automatic_reconnect(
601+
{
602+
"type": "raw",
603+
"keep_alive_interval": 10,
604+
"reconnect_interval": 5,
605+
"max_attempts": 5,
606+
}
607+
)
608+
).build()
609+
610+
_ws_lock = asyncio.Lock()
611+
612+
def _lock_wrapper(fn: Callable):
613+
@wraps(fn)
614+
async def _wrapper(*args, **kwargs):
615+
async with _ws_lock:
616+
return await fn(*args, **kwargs)
617+
618+
return _wrapper
606619

607-
self._client.on_open(self._on_connect)
608-
self._client.on_error(self._on_error)
609-
self._client.on('operations', self._on_operation_message)
610-
self._client.on('bigmaps', self._on_big_map_message)
611-
self._client.on('head', self._on_head_message)
620+
self._ws_client.on_open(_lock_wrapper(self._on_connect))
621+
self._ws_client.on_error(_lock_wrapper(self._on_error))
622+
self._ws_client.on('operations', _lock_wrapper(self._on_operation_message))
623+
self._ws_client.on('bigmaps', _lock_wrapper(self._on_big_map_message))
624+
self._ws_client.on('head', _lock_wrapper(self._on_head_message))
612625

613-
return self._client
626+
return self._ws_client
614627

615628
async def run(self) -> None:
616629
"""Main loop. Sync indexes via REST, start WS connection"""
617630
self._logger.info('Starting datasource')
618631

619632
self._logger.info('Starting websocket client')
620-
await self._get_client().start()
633+
await self._get_ws_client().start()
621634

622635
async def _on_connect(self) -> None:
623636
"""Subscribe to all required channels on established WS connection"""
624-
if self._get_client().transport.state != ConnectionState.connected:
637+
if self._get_ws_client().transport.state != ConnectionState.connected:
625638
return
626639

627640
self._logger.info('Realtime connection established, subscribing to channels')
@@ -932,7 +945,7 @@ def convert_quote(cls, quote_json: Dict[str, Any]) -> QuoteData:
932945
)
933946

934947
async def _send(self, method: str, arguments: List[Dict[str, Any]], on_invocation=None) -> None:
935-
client = self._get_client()
948+
client = self._get_ws_client()
936949
while client.transport.state != ConnectionState.connected:
937950
await asyncio.sleep(0.1)
938951
await client.send(method, arguments, on_invocation)

0 commit comments

Comments
 (0)