|
4 | 4 | from datetime import datetime, timezone |
5 | 5 | from decimal import Decimal |
6 | 6 | from enum import Enum |
7 | | -from typing import Any, AsyncGenerator, DefaultDict, Dict, List, NoReturn, Optional, Set, Tuple, cast |
| 7 | +from functools import wraps |
| 8 | +from typing import Any, AsyncGenerator, Callable, DefaultDict, Dict, List, NoReturn, Optional, Set, Tuple, cast |
8 | 9 |
|
9 | 10 | from aiohttp import ClientResponseError |
10 | 11 | from aiosignalrcore.hub.base_hub_connection import BaseHubConnection # type: ignore |
@@ -340,7 +341,7 @@ def __init__( |
340 | 341 | self._transaction_subscriptions: Set[str] = set() |
341 | 342 | self._origination_subscriptions: bool = False |
342 | 343 | self._big_map_subscriptions: Dict[str, Set[str]] = {} |
343 | | - self._client: Optional[BaseHubConnection] = None |
| 344 | + self._ws_client: Optional[BaseHubConnection] = None |
344 | 345 |
|
345 | 346 | self._block_cache: BlockCache = BlockCache() |
346 | 347 | self._head: Optional[Head] = None |
@@ -587,41 +588,53 @@ async def add_index(self, index_config: ResolvedIndexConfigT) -> None: |
587 | 588 |
|
588 | 589 | await self._on_connect() |
589 | 590 |
|
590 | | - def _get_client(self) -> BaseHubConnection: |
| 591 | + def _get_ws_client(self) -> BaseHubConnection: |
591 | 592 | """Create SignalR client, register message callbacks""" |
592 | | - if self._client is None: |
593 | | - self._logger.info('Creating websocket client') |
594 | | - self._client = ( |
595 | | - HubConnectionBuilder() |
596 | | - .with_url(self._http._url + '/v1/events') |
597 | | - .with_automatic_reconnect( |
598 | | - { |
599 | | - "type": "raw", |
600 | | - "keep_alive_interval": 10, |
601 | | - "reconnect_interval": 5, |
602 | | - "max_attempts": 5, |
603 | | - } |
604 | | - ) |
605 | | - ).build() |
| 593 | + if self._ws_client: |
| 594 | + return self._ws_client |
| 595 | + |
| 596 | + self._logger.info('Creating websocket client') |
| 597 | + self._ws_client = ( |
| 598 | + HubConnectionBuilder() |
| 599 | + .with_url(self._http._url + '/v1/events') |
| 600 | + .with_automatic_reconnect( |
| 601 | + { |
| 602 | + "type": "raw", |
| 603 | + "keep_alive_interval": 10, |
| 604 | + "reconnect_interval": 5, |
| 605 | + "max_attempts": 5, |
| 606 | + } |
| 607 | + ) |
| 608 | + ).build() |
| 609 | + |
| 610 | + _ws_lock = asyncio.Lock() |
| 611 | + |
| 612 | + def _lock_wrapper(fn: Callable): |
| 613 | + @wraps(fn) |
| 614 | + async def _wrapper(*args, **kwargs): |
| 615 | + async with _ws_lock: |
| 616 | + return await fn(*args, **kwargs) |
| 617 | + |
| 618 | + return _wrapper |
606 | 619 |
|
607 | | - self._client.on_open(self._on_connect) |
608 | | - self._client.on_error(self._on_error) |
609 | | - self._client.on('operations', self._on_operation_message) |
610 | | - self._client.on('bigmaps', self._on_big_map_message) |
611 | | - self._client.on('head', self._on_head_message) |
| 620 | + self._ws_client.on_open(_lock_wrapper(self._on_connect)) |
| 621 | + self._ws_client.on_error(_lock_wrapper(self._on_error)) |
| 622 | + self._ws_client.on('operations', _lock_wrapper(self._on_operation_message)) |
| 623 | + self._ws_client.on('bigmaps', _lock_wrapper(self._on_big_map_message)) |
| 624 | + self._ws_client.on('head', _lock_wrapper(self._on_head_message)) |
612 | 625 |
|
613 | | - return self._client |
| 626 | + return self._ws_client |
614 | 627 |
|
615 | 628 | async def run(self) -> None: |
616 | 629 | """Main loop. Sync indexes via REST, start WS connection""" |
617 | 630 | self._logger.info('Starting datasource') |
618 | 631 |
|
619 | 632 | self._logger.info('Starting websocket client') |
620 | | - await self._get_client().start() |
| 633 | + await self._get_ws_client().start() |
621 | 634 |
|
622 | 635 | async def _on_connect(self) -> None: |
623 | 636 | """Subscribe to all required channels on established WS connection""" |
624 | | - if self._get_client().transport.state != ConnectionState.connected: |
| 637 | + if self._get_ws_client().transport.state != ConnectionState.connected: |
625 | 638 | return |
626 | 639 |
|
627 | 640 | self._logger.info('Realtime connection established, subscribing to channels') |
@@ -932,7 +945,7 @@ def convert_quote(cls, quote_json: Dict[str, Any]) -> QuoteData: |
932 | 945 | ) |
933 | 946 |
|
934 | 947 | async def _send(self, method: str, arguments: List[Dict[str, Any]], on_invocation=None) -> None: |
935 | | - client = self._get_client() |
| 948 | + client = self._get_ws_client() |
936 | 949 | while client.transport.state != ConnectionState.connected: |
937 | 950 | await asyncio.sleep(0.1) |
938 | 951 | await client.send(method, arguments, on_invocation) |
|
0 commit comments