Skip to content

Commit fd0f0cf

Browse files
Merge pull request #157 from InjectiveLabs/f/add_get_account
chore: fix get_account with EthAccount type and cookie
2 parents c2ab981 + 11cb986 commit fd0f0cf

File tree

3 files changed

+66
-31
lines changed

3 files changed

+66
-31
lines changed

examples/chain_client/1_MsgSend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ async def main() -> None:
3535
# load account
3636
priv_key = PrivateKey.from_hex("f9db9bf330e23cb7839039e944adef6e9df447b90b503d5b4464c90bea9022f3")
3737
pub_key = priv_key.to_public_key()
38-
address = await pub_key.to_address().async_init_num_seq(network.lcd_endpoint)
38+
address = pub_key.to_address()
39+
account = await client.get_account(address.to_acc_bech32())
3940

4041
# prepare tx msg
4142
msg = composer.MsgSend(
@@ -49,8 +50,8 @@ async def main() -> None:
4950
tx = (
5051
Transaction()
5152
.with_messages(msg)
52-
.with_sequence(address.get_sequence())
53-
.with_account_num(address.get_number())
53+
.with_sequence(client.get_sequence())
54+
.with_account_num(client.get_number())
5455
.with_chain_id(network.chain_id)
5556
)
5657
sim_sign_doc = tx.get_sign_doc(pub_key)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import asyncio
2+
import logging
3+
4+
from pyinjective.async_client import AsyncClient
5+
from pyinjective.constant import Network
6+
7+
async def main() -> None:
8+
network = Network.testnet()
9+
client = AsyncClient(network, insecure=False)
10+
address = "inj1knhahceyp57j5x7xh69p7utegnnnfgxavmahjr"
11+
acc = await client.get_account(address=address)
12+
print(acc)
13+
14+
if __name__ == '__main__':
15+
logging.basicConfig(level=logging.INFO)
16+
asyncio.get_event_loop().run_until_complete(main())

pyinjective/async_client.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -58,43 +58,50 @@
5858
injective_auction_rpc_pb2_grpc as auction_rpc_grpc,
5959
)
6060

61+
from proto.injective.types.v1beta1 import (
62+
account_pb2
63+
)
64+
6165
from .constant import Network
6266

6367
DEFAULT_TIMEOUTHEIGHT_SYNC_INTERVAL = 20 # seconds
6468
DEFAULT_TIMEOUTHEIGHT = 30 # blocks
6569
DEFAULT_SESSION_RENEWAL_OFFSET = 120 # seconds
6670
DEFAULT_BLOCK_TIME = 2 # seconds
6771

68-
6972
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
7073

7174

7275
class AsyncClient:
7376
def __init__(
74-
self,
75-
network: Network,
76-
insecure: bool = False,
77-
load_balancer: bool = False,
78-
credentials = grpc.ssl_channel_credentials(),
79-
chain_cookie_location = ".chain_cookie"
77+
self,
78+
network: Network,
79+
insecure: bool = False,
80+
load_balancer: bool = False,
81+
credentials=grpc.ssl_channel_credentials(),
82+
chain_cookie_location=".chain_cookie",
8083
):
8184

8285
# use append mode to create file if not exist
8386
self.chain_cookie_location = chain_cookie_location
8487
cookie_file = open(chain_cookie_location, "a+")
8588
cookie_file.close()
86-
89+
90+
self.addr = ""
91+
self.number = 0
92+
self.sequence = 0
93+
8794
self.cookie_type = None
8895
self.expiration_format = None
8996
self.load_balancer = load_balancer
9097

9198
if self.load_balancer is False:
92-
self.cookie_type = "grpc-cookie"
93-
self.expiration_format = "20{}"
99+
self.cookie_type = "grpc-cookie"
100+
self.expiration_format = "20{}"
94101

95102
else:
96-
self.cookie_type = "GCLB"
97-
self.expiration_format = "{}"
103+
self.cookie_type = "GCLB"
104+
self.expiration_format = "{}"
98105

99106
# chain stubs
100107
self.chain_channel = (
@@ -161,6 +168,14 @@ def __init__(
161168
start=True,
162169
)
163170

171+
def get_sequence(self):
172+
current_seq = self.sequence
173+
self.sequence += 1
174+
return current_seq
175+
176+
def get_number(self):
177+
return self.number
178+
164179
async def get_tx(self, tx_hash):
165180
return await self.stubTx.GetTx(tx_service.GetTxRequest(hash=tx_hash))
166181

@@ -199,9 +214,9 @@ async def renew_cookie(self, existing_cookie, type):
199214
# format cookie date into RFC1123 standard
200215
cookie = SimpleCookie()
201216
cookie.load(existing_cookie)
202-
217+
203218
expires_at = cookie.get(f"{self.cookie_type}").get("expires")
204-
expires_at = expires_at.replace("-"," ")
219+
expires_at = expires_at.replace("-", " ")
205220
yyyy = f"{self.expiration_format}".format(expires_at[12:14])
206221
expires_at = expires_at[:12] + yyyy + expires_at[14:]
207222

@@ -270,16 +285,19 @@ async def get_latest_block(self) -> tendermint_query.GetLatestBlockResponse:
270285
req = tendermint_query.GetLatestBlockRequest()
271286
return await self.stubCosmosTendermint.GetLatestBlock(req)
272287

273-
async def get_account(self, address: str) -> Optional[auth_type.BaseAccount]:
288+
async def get_account(self, address: str) -> Optional[account_pb2.EthAccount]:
274289
try:
275-
account_any = await self.stubAuth.Account(
276-
auth_query.QueryAccountRequest(address=address)
277-
).account
278-
account = auth_type.BaseAccount()
290+
metadata = await self.load_cookie(type="chain")
291+
account_any = (await self.stubAuth.Account(
292+
auth_query.QueryAccountRequest.__call__(address=address), metadata=metadata
293+
)).account
294+
account = account_pb2.EthAccount()
279295
if account_any.Is(account.DESCRIPTOR):
280296
account_any.Unpack(account)
281-
return account
282-
except:
297+
self.number = int(account.base_account.account_number)
298+
self.sequence = int(account.base_account.sequence)
299+
except Exception as e:
300+
logging.debug("error while fetching sequence and number{}".format(e))
283301
return None
284302

285303
async def get_request_id_by_tx_hash(self, tx_hash: bytes) -> List[int]:
@@ -302,7 +320,7 @@ async def get_request_id_by_tx_hash(self, tx_hash: bytes) -> List[int]:
302320
return request_ids
303321

304322
async def simulate_tx(
305-
self, tx_byte: bytes
323+
self, tx_byte: bytes
306324
) -> Tuple[Union[abci_type.SimulationResponse, grpc.RpcError], bool]:
307325
try:
308326
req = tx_service.SimulateRequest(tx_bytes=tx_byte)
@@ -538,19 +556,19 @@ async def get_rewards(self, **kwargs):
538556
# OracleRPC
539557

540558
async def stream_oracle_prices(
541-
self, base_symbol: str, quote_symbol: str, oracle_type: str
559+
self, base_symbol: str, quote_symbol: str, oracle_type: str
542560
):
543561
req = oracle_rpc_pb.StreamPricesRequest(
544562
base_symbol=base_symbol, quote_symbol=quote_symbol, oracle_type=oracle_type
545563
)
546564
return self.stubOracle.StreamPrices(req)
547565

548566
async def get_oracle_prices(
549-
self,
550-
base_symbol: str,
551-
quote_symbol: str,
552-
oracle_type: str,
553-
oracle_scale_factor: int,
567+
self,
568+
base_symbol: str,
569+
quote_symbol: str,
570+
oracle_type: str,
571+
oracle_scale_factor: int,
554572
):
555573
req = oracle_rpc_pb.PriceRequest(
556574
base_symbol=base_symbol,

0 commit comments

Comments
 (0)