Skip to content

Commit ee19f3f

Browse files
committed
(feat) Add pytest-grpc plugin to mock grpc iteractions
1 parent 0f23864 commit ee19f3f

File tree

11 files changed

+304
-6
lines changed

11 files changed

+304
-6
lines changed

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ websockets = "*"
2323
[dev-packages]
2424
pytest = "*"
2525
pytest-asyncio = "*"
26+
pytest-grpc = "*"
2627

2728
[requires]
2829
python_version = "3"

Pipfile.lock

Lines changed: 9 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyinjective/async_client.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
import os
1+
import asyncio
22
import time
33
import grpc
44
import aiocron
5-
import logging
65
import datetime
6+
from decimal import Decimal
77
from http.cookies import SimpleCookie
8-
from typing import List, Optional, Tuple, Union
8+
from typing import Dict, List, Optional, Tuple, Union
99

10+
from .core.market import Market
11+
from .core.token import Token
1012
from .exceptions import NotFoundError, EmptyMsgError
1113

1214
from .proto.cosmos.base.abci.v1beta1 import abci_pb2 as abci_type
@@ -179,6 +181,11 @@ def __init__(
179181
start=True,
180182
)
181183

184+
self._tokens_and_markets_initialization_lock = asyncio.Lock()
185+
self._tokens: Optional[Dict[str, Token]] = None
186+
self._spot_markets: Optional[Dict[str, Market]] = None
187+
self._derivative_markets: Optional[Dict]
188+
182189
def get_sequence(self):
183190
current_seq = self.sequence
184191
self.sequence += 1
@@ -976,7 +983,7 @@ async def get_funding_rates(self, market_id: str, **kwargs):
976983
skip=kwargs.get("skip"),
977984
limit=kwargs.get("limit"),
978985
end_time=kwargs.get("end_time"),
979-
)
986+
)
980987
return await self.stubDerivativeExchange.FundingRates(req)
981988

982989
async def get_binary_options_markets(self, **kwargs):
@@ -1006,3 +1013,57 @@ async def stream_account_portfolio(self, account_address: str, **kwargs):
10061013
)
10071014
metadata = await self.load_cookie(type="exchange")
10081015
return self.stubPortfolio.StreamAccountPortfolio.__call__(req, metadata=metadata)
1016+
1017+
async def _initialize_tokens_and_markets(self):
1018+
markets = dict()
1019+
tokens = dict()
1020+
markets_info = (await self.get_spot_markets()).markets
1021+
1022+
for market_info in markets_info:
1023+
base_token = tokens.get(market_info.base_token.symbol)
1024+
if base_token is None:
1025+
base_token = Token(
1026+
name=market_info.base_token.name,
1027+
symbol=market_info.base_token.symbol,
1028+
decimals=market_info.base_token.decimals,
1029+
logo=market_info.base_token.logo,
1030+
)
1031+
tokens[base_token.symbol] = base_token
1032+
base_token.add_source(
1033+
denom=market_info.base_denom,
1034+
symbol=market_info.base_token.symbol,
1035+
address=market_info.base_token.address,
1036+
decimals=market_info.base_token.decimals,
1037+
updated=market_info.base_token.updated_at,
1038+
)
1039+
1040+
quote_token = tokens.get(market_info.base_token.symbol)
1041+
if quote_token is None:
1042+
quote_token = Token(
1043+
name=market_info.quote_token.name,
1044+
symbol=market_info.quote_token.symbol,
1045+
decimals=market_info.quote_token.decimals,
1046+
logo=market_info.quote_token.logo,
1047+
)
1048+
quote_token.add_source(
1049+
denom=market_info.quote_denom,
1050+
symbol=market_info.quote_token.symbol,
1051+
address=market_info.quote_token.address,
1052+
decimals=market_info.quote_token.decimals,
1053+
updated=market_info.quote_token.updated_at,
1054+
)
1055+
1056+
market = Market(
1057+
id=market_info.market_id,
1058+
status=market_info.market_status,
1059+
ticker=market_info.ticker,
1060+
base_token=base_token,
1061+
quote_token=quote_token,
1062+
maker_fee_rate=Decimal(market_info.maker_fee_rate),
1063+
taker_fee_rate=Decimal(market_info.taker_fee_rate),
1064+
service_provider_fee=Decimal(market_info.service_provider_fee),
1065+
min_price_tick_size=Decimal(market_info.min_price_tick_size),
1066+
min_quantity_tick_size=Decimal(market_info.min_quantity_tick_size),
1067+
)
1068+
1069+
markets[market.id] = market

pyinjective/core/__init__.py

Whitespace-only changes.

pyinjective/core/market.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from dataclasses import dataclass
2+
from decimal import Decimal
3+
4+
from pyinjective.core.token import Token
5+
6+
7+
@dataclass
8+
class Market:
9+
id: str
10+
status: str
11+
ticker: str
12+
base_token: Token
13+
quote_token: Token
14+
maker_fee_rate: Decimal
15+
taker_fee_rate: Decimal
16+
service_provider_fee: Decimal
17+
min_price_tick_size: Decimal
18+
min_quantity_tick_size: Decimal

pyinjective/core/token.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Dict
2+
3+
4+
class TokenSource:
5+
6+
def __init__(
7+
self,
8+
denom: str,
9+
symbol: str,
10+
address: str,
11+
decimals: int,
12+
updated: int,
13+
):
14+
self.denom = denom
15+
self.symbol = symbol
16+
self.address = address
17+
self.decimals = decimals
18+
self.updated = updated
19+
20+
21+
class Token:
22+
23+
def __init__(
24+
self,
25+
name: str,
26+
symbol: str,
27+
decimals: int,
28+
logo: str,
29+
):
30+
self.name = name
31+
self.symbol = symbol
32+
self.decimals = decimals
33+
self.logo = logo
34+
35+
self._sources = Dict[str, TokenSource]
36+
37+
def add_source(self, denom: str, symbol: str, address: str, decimals: int, updated: int):
38+
if denom not in self._sources:
39+
token_source = TokenSource(
40+
denom=denom,
41+
symbol=symbol,
42+
address=address,
43+
decimals=decimals,
44+
updated=updated,
45+
)
46+
self._sources[denom] = token_source

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
DEV_REQUIRED = [
4141
"pytest",
4242
"pytest-asyncio",
43+
"pytest-grpc",
4344
]
4445

4546
# The rest you shouldn't have to touch too much :)

tests/async_client_tests.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,33 @@
44
from pyinjective.constant import Network
55

66
from pyinjective.async_client import AsyncClient
7+
from pyinjective.proto.exchange.injective_spot_exchange_rpc_pb2 import (
8+
MarketRequest,
9+
MarketsRequest,
10+
MarketResponse,
11+
MarketsResponse,
12+
)
13+
from tests.rpc_fixtures.markets_fixtures import inj_token_meta, ape_token_meta, usdt_token_meta, inj_usdt_spot_market, ape_usdt_spot_market
714

15+
@pytest.fixture(scope='module')
16+
def grpc_add_to_server():
17+
from pyinjective.proto.exchange.injective_spot_exchange_rpc_pb2_grpc import add_InjectiveSpotExchangeRPCServicer_to_server
18+
19+
return add_InjectiveSpotExchangeRPCServicer_to_server
20+
21+
22+
@pytest.fixture(scope='module')
23+
def grpc_servicer():
24+
from tests.rpc_fixtures.configurable_servicers import ConfigurableInjectiveSpotExchangeRPCServicer
25+
26+
return ConfigurableInjectiveSpotExchangeRPCServicer()
27+
28+
29+
@pytest.fixture(scope='module')
30+
def grpc_stub_cls(grpc_channel):
31+
from pyinjective.proto.exchange.injective_spot_exchange_rpc_pb2_grpc import InjectiveSpotExchangeRPCStub
32+
33+
return InjectiveSpotExchangeRPCStub
834

935
class TestAsyncClient:
1036

@@ -81,4 +107,34 @@ async def test_get_account_logs_exception(self, caplog):
81107
)
82108
assert (found_log is not None)
83109
assert (found_log[0] == "pyinjective.async_client.AsyncClient")
84-
assert (found_log[1] == logging.DEBUG)
110+
assert (found_log[1] == logging.DEBUG)
111+
112+
@pytest.mark.asyncio
113+
async def test_initialize_tokens_and_markets(
114+
self,
115+
grpc_stub,
116+
grpc_servicer,
117+
inj_usdt_spot_market,
118+
ape_usdt_spot_market
119+
):
120+
grpc_servicer.markets_queue.append(MarketsResponse(
121+
markets=[inj_usdt_spot_market, ape_usdt_spot_market]
122+
))
123+
124+
client = AsyncClient(
125+
network=Network.local(),
126+
insecure=False,
127+
)
128+
129+
client.stubSpotExchange = grpc_stub
130+
131+
await client._initialize_tokens_and_markets()
132+
133+
assert(3 == len(client.tokens))
134+
assert(any((inj_usdt_spot_market.base_token_meta.symbol == token.symbol for token in client.tokens)))
135+
assert (any((inj_usdt_spot_market.quote_token_meta.symbol == token.symbol for token in client.tokens)))
136+
assert (any((ape_usdt_spot_market.base_token_meta.symbol == token.symbol for token in client.tokens)))
137+
138+
assert (2 == len(client.markets))
139+
assert (any((inj_usdt_spot_market.market_id == market.id for market in client.markets)))
140+
assert (any((ape_usdt_spot_market.market_id == market.id for market in client.markets)))

tests/rpc_fixtures/__init__.py

Whitespace-only changes.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from collections import deque
2+
3+
from pyinjective.proto.exchange.injective_spot_exchange_rpc_pb2 import (
4+
MarketsRequest,
5+
MarketsResponse,
6+
)
7+
from pyinjective.proto.exchange.injective_spot_exchange_rpc_pb2_grpc import InjectiveSpotExchangeRPCServicer
8+
9+
10+
class ConfigurableInjectiveSpotExchangeRPCServicer(InjectiveSpotExchangeRPCServicer):
11+
12+
def __init__(self):
13+
super().__init__()
14+
self.markets_queue = deque()
15+
16+
def Markets(self, request: MarketsRequest, context):
17+
return self.markets_queue.pop()
18+

0 commit comments

Comments
 (0)