Skip to content

Commit 42da90e

Browse files
authored
[CHIA-3294] Port get_wallet_balance(s) to @marshal (#19774)
* Port `get_wallet_balance(s)` * Use Streamable dict functionality
1 parent ec21ca9 commit 42da90e

File tree

9 files changed

+171
-99
lines changed

9 files changed

+171
-99
lines changed

chia/_tests/cmds/wallet/test_wallet.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from chia_rs import Coin, G2Element
1111
from chia_rs.sized_bytes import bytes32
12-
from chia_rs.sized_ints import uint8, uint16, uint32, uint64
12+
from chia_rs.sized_ints import uint8, uint16, uint32, uint64, uint128
1313
from click.testing import CliRunner
1414

1515
from chia._tests.cmds.cmd_test_utils import TestRpcClients, TestWalletRpcClient, logType, run_cli_command_and_assert
@@ -41,11 +41,14 @@
4141
from chia.wallet.util.wallet_types import WalletType
4242
from chia.wallet.wallet_coin_store import GetCoinRecords
4343
from chia.wallet.wallet_request_types import (
44+
BalanceResponse,
4445
CancelOfferResponse,
4546
CATSpendResponse,
4647
CreateOfferForIDsResponse,
4748
FungibleAsset,
4849
GetHeightInfoResponse,
50+
GetWalletBalance,
51+
GetWalletBalanceResponse,
4952
GetWallets,
5053
GetWalletsResponse,
5154
NFTCalculateRoyalties,
@@ -243,19 +246,23 @@ async def get_height_info(self) -> GetHeightInfoResponse:
243246
self.add_to_log("get_height_info", ())
244247
return GetHeightInfoResponse(uint32(10))
245248

246-
async def get_wallet_balance(self, wallet_id: int) -> dict[str, uint64]:
247-
self.add_to_log("get_wallet_balance", (wallet_id,))
248-
if wallet_id == 1:
249-
amount = uint64(1000000000)
250-
elif wallet_id == 2:
251-
amount = uint64(2000000000)
249+
async def get_wallet_balance(self, request: GetWalletBalance) -> GetWalletBalanceResponse:
250+
self.add_to_log("get_wallet_balance", (request,))
251+
if request.wallet_id == 1:
252+
amount = uint128(1000000000)
253+
elif request.wallet_id == 2:
254+
amount = uint128(2000000000)
252255
else:
253-
amount = uint64(1)
254-
return {
255-
"confirmed_wallet_balance": amount,
256-
"spendable_balance": amount,
257-
"unconfirmed_wallet_balance": uint64(0),
258-
}
256+
amount = uint128(1)
257+
return GetWalletBalanceResponse(
258+
BalanceResponse(
259+
wallet_id=request.wallet_id,
260+
wallet_type=uint8(0), # Doesn't matter
261+
confirmed_wallet_balance=amount,
262+
spendable_balance=amount,
263+
unconfirmed_wallet_balance=uint128(0),
264+
)
265+
)
259266

260267
async def get_nft_wallet_did(self, request: NFTGetWalletDID) -> NFTGetWalletDIDResponse:
261268
self.add_to_log("get_nft_wallet_did", (request.wallet_id,))
@@ -308,7 +315,12 @@ async def get_connections(
308315
],
309316
"get_sync_status": [(), ()],
310317
"get_height_info": [(), ()],
311-
"get_wallet_balance": [(1,), (2,), (3,), (2,)],
318+
"get_wallet_balance": [
319+
(GetWalletBalance(wallet_id=uint32(1)),),
320+
(GetWalletBalance(wallet_id=uint32(2)),),
321+
(GetWalletBalance(wallet_id=uint32(3)),),
322+
(GetWalletBalance(wallet_id=uint32(2)),),
323+
],
312324
"get_nft_wallet_did": [(3,)],
313325
"get_connections": [(None,), (None,)],
314326
}

chia/_tests/environments/wallet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from chia.wallet.wallet import Wallet
2626
from chia.wallet.wallet_node import Balance, WalletNode
2727
from chia.wallet.wallet_node_api import WalletNodeAPI
28+
from chia.wallet.wallet_request_types import GetWalletBalance
2829
from chia.wallet.wallet_rpc_api import WalletRpcApi
2930
from chia.wallet.wallet_rpc_client import WalletRpcClient
3031
from chia.wallet.wallet_state_manager import WalletStateManager
@@ -169,7 +170,9 @@ async def check_balances(self, additional_balance_info: dict[Union[int, str], di
169170
else {}
170171
),
171172
}
172-
balance_response: dict[str, int] = await self.rpc_client.get_wallet_balance(wallet_id)
173+
balance_response: dict[str, int] = (
174+
await self.rpc_client.get_wallet_balance(GetWalletBalance(wallet_id))
175+
).wallet_balance.to_json_dict()
173176

174177
if not expected_result.items() <= balance_response.items():
175178
for key, value in expected_result.items():

chia/_tests/pools/test_pool_rpc.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,14 @@
4141
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
4242
from chia.wallet.util.wallet_types import WalletType
4343
from chia.wallet.wallet_node import WalletNode
44-
from chia.wallet.wallet_request_types import GetWallets, PWAbsorbRewards, PWJoinPool, PWSelfPool, PWStatus
44+
from chia.wallet.wallet_request_types import (
45+
GetWalletBalance,
46+
GetWallets,
47+
PWAbsorbRewards,
48+
PWJoinPool,
49+
PWSelfPool,
50+
PWStatus,
51+
)
4552
from chia.wallet.wallet_rpc_client import WalletRpcClient
4653
from chia.wallet.wallet_state_manager import WalletStateManager
4754

@@ -468,8 +475,8 @@ def mempool_empty() -> bool:
468475
assert len(asset_id) > 0
469476
await full_node_api.process_all_wallet_transactions(wallet=wallet)
470477
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
471-
bal_0 = await client.get_wallet_balance(cat_0_id)
472-
assert bal_0["confirmed_wallet_balance"] == 20
478+
bal_0 = (await client.get_wallet_balance(GetWalletBalance(cat_0_id))).wallet_balance
479+
assert bal_0.confirmed_wallet_balance == 20
473480

474481
# Test creation of many pool wallets. Use untrusted since that is the more complicated protocol, but don't
475482
# run this code more than once, since it's slow.
@@ -535,8 +542,8 @@ async def test_absorb_self(
535542
await add_blocks_in_batches(blocks[-3:], full_node_api.full_node)
536543
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
537544

538-
bal = await client.get_wallet_balance(2)
539-
assert bal["confirmed_wallet_balance"] == 2 * 1_750_000_000_000
545+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
546+
assert bal.confirmed_wallet_balance == 2 * 1_750_000_000_000
540547

541548
# Claim 2 * 1.75, and farm a new 1.75
542549
absorb_txs = (
@@ -561,8 +568,8 @@ async def test_absorb_self(
561568
new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
562569
assert status.current == new_status.current
563570
assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id
564-
bal = await client.get_wallet_balance(2)
565-
assert bal["confirmed_wallet_balance"] == 1 * 1_750_000_000_000
571+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
572+
assert bal.confirmed_wallet_balance == 1 * 1_750_000_000_000
566573

567574
# Claim another 1.75
568575
absorb_txs1 = (
@@ -575,8 +582,8 @@ async def test_absorb_self(
575582

576583
await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True)
577584
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
578-
bal = await client.get_wallet_balance(2)
579-
assert bal["confirmed_wallet_balance"] == 0
585+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
586+
assert bal.confirmed_wallet_balance == 0
580587

581588
assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0
582589

@@ -590,8 +597,8 @@ async def test_absorb_self(
590597
await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True)
591598

592599
# Balance ignores non coinbase TX
593-
bal = await client.get_wallet_balance(2)
594-
assert bal["confirmed_wallet_balance"] == 0
600+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
601+
assert bal.confirmed_wallet_balance == 0
595602

596603
with pytest.raises(ValueError):
597604
await client.pw_absorb_rewards(
@@ -626,8 +633,8 @@ async def test_absorb_self_multiple_coins(
626633
pool_expected_confirmed_balance = 0
627634

628635
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
629-
main_bal = await client.get_wallet_balance(1)
630-
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance
636+
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
637+
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance
631638

632639
status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
633640
assert status.current.state == PoolSingletonState.SELF_POOLING.value
@@ -650,10 +657,10 @@ async def test_absorb_self_multiple_coins(
650657
pool_expected_confirmed_balance += block_count * 1_750_000_000_000
651658
main_expected_confirmed_balance += block_count * 250_000_000_000
652659

653-
main_bal = await client.get_wallet_balance(1)
654-
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance
655-
bal = await client.get_wallet_balance(2)
656-
assert bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance
660+
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
661+
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance
662+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
663+
assert bal.confirmed_wallet_balance == pool_expected_confirmed_balance
657664

658665
# Claim
659666
absorb_txs = (
@@ -671,10 +678,10 @@ async def test_absorb_self_multiple_coins(
671678
new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
672679
assert status.current == new_status.current
673680
assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id
674-
main_bal = await client.get_wallet_balance(1)
675-
pool_bal = await client.get_wallet_balance(2)
676-
assert pool_bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance
677-
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance # 10499999999999
681+
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
682+
pool_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
683+
assert pool_bal.confirmed_wallet_balance == pool_expected_confirmed_balance
684+
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance # 10499999999999
678685

679686
@pytest.mark.anyio
680687
async def test_absorb_pooling(
@@ -726,8 +733,8 @@ async def farming_to_pool() -> bool:
726733
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
727734
# Pooled plots don't have balance
728735
main_expected_confirmed_balance += block_count * 250_000_000_000
729-
bal = await client.get_wallet_balance(2)
730-
assert bal["confirmed_wallet_balance"] == 0
736+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
737+
assert bal.confirmed_wallet_balance == 0
731738

732739
# Claim block_count * 1.75
733740
ret = await client.pw_absorb_rewards(
@@ -751,12 +758,12 @@ async def status_updated() -> bool:
751758

752759
await time_out_assert(20, status_updated)
753760
new_status = (await client.pw_status(PWStatus(uint32(2)))).state
754-
bal = await client.get_wallet_balance(2)
755-
assert bal["confirmed_wallet_balance"] == 0
761+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
762+
assert bal.confirmed_wallet_balance == 0
756763

757764
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
758-
bal = await client.get_wallet_balance(2)
759-
assert bal["confirmed_wallet_balance"] == 0
765+
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
766+
assert bal.confirmed_wallet_balance == 0
760767
assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0
761768
peak = full_node_api.full_node.blockchain.get_peak()
762769
assert peak is not None
@@ -798,8 +805,8 @@ async def status_updated() -> bool:
798805
status = (await client.pw_status(PWStatus(uint32(2)))).state
799806
assert ret.fee_transaction is None
800807

801-
bal2 = await client.get_wallet_balance(2)
802-
assert bal2["confirmed_wallet_balance"] == 0
808+
bal2 = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
809+
assert bal2.confirmed_wallet_balance == 0
803810

804811
@pytest.mark.anyio
805812
async def test_self_pooling_to_pooling(self, setup: Setup, fee: uint64, self_hostname: str) -> None:

chia/_tests/wallet/rpc/test_wallet_rpc.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import random
99
from collections.abc import AsyncIterator
1010
from operator import attrgetter
11-
from typing import Any, Optional, cast
11+
from typing import Any, Optional
1212
from unittest.mock import patch
1313

1414
import aiosqlite
@@ -123,6 +123,8 @@
123123
GetPrivateKey,
124124
GetSyncStatusResponse,
125125
GetTimestampForHeight,
126+
GetWalletBalance,
127+
GetWalletBalances,
126128
GetWallets,
127129
LogIn,
128130
NFTCalculateRoyalties,
@@ -191,7 +193,9 @@ async def farm_transaction(
191193

192194
async def generate_funds(full_node_api: FullNodeSimulator, wallet_bundle: WalletBundle, num_blocks: int = 1) -> int:
193195
wallet_id = 1
194-
initial_balances = await wallet_bundle.rpc_client.get_wallet_balance(wallet_id)
196+
initial_balances = (
197+
await wallet_bundle.rpc_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))
198+
).wallet_balance
195199
ph: bytes32 = decode_puzzle_hash(await wallet_bundle.rpc_client.get_next_address(wallet_id, True))
196200
generated_funds = 0
197201
for _ in range(num_blocks):
@@ -203,8 +207,8 @@ async def generate_funds(full_node_api: FullNodeSimulator, wallet_bundle: Wallet
203207
# Farm a dummy block to confirm the created funds
204208
await farm_transaction_block(full_node_api, wallet_bundle.node)
205209

206-
expected_confirmed = initial_balances["confirmed_wallet_balance"] + generated_funds
207-
expected_unconfirmed = initial_balances["unconfirmed_wallet_balance"] + generated_funds
210+
expected_confirmed = initial_balances.confirmed_wallet_balance + generated_funds
211+
expected_unconfirmed = initial_balances.unconfirmed_wallet_balance + generated_funds
208212
await time_out_assert(20, get_confirmed_balance, expected_confirmed, wallet_bundle.rpc_client, wallet_id)
209213
await time_out_assert(20, get_unconfirmed_balance, expected_unconfirmed, wallet_bundle.rpc_client, wallet_id)
210214
await time_out_assert(20, check_client_synced, True, wallet_bundle.rpc_client)
@@ -326,13 +330,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor
326330
async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol[Any]) -> None:
327331
expected_balance = await wallet_node.get_balance(wallet.id())
328332
expected_balance_dict = expected_balance.to_json_dict()
333+
expected_balance_dict.setdefault("pending_approval_balance", None)
329334
expected_balance_dict["wallet_id"] = wallet.id()
330335
expected_balance_dict["wallet_type"] = wallet.type()
331336
expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint
332337
if wallet.type() in {WalletType.CAT, WalletType.CRCAT}:
333338
assert isinstance(wallet, CATWallet)
334-
expected_balance_dict["asset_id"] = wallet.get_asset_id()
335-
assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict
339+
expected_balance_dict["asset_id"] = "0x" + wallet.get_asset_id()
340+
else:
341+
expected_balance_dict["asset_id"] = None
342+
assert (
343+
await rpc_client.get_wallet_balance(GetWalletBalance(wallet.id()))
344+
).wallet_balance.to_json_dict() == expected_balance_dict
336345

337346

338347
async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32) -> bool:
@@ -341,15 +350,15 @@ async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32) -> boo
341350

342351

343352
async def get_confirmed_balance(client: WalletRpcClient, wallet_id: int) -> uint128:
344-
balance = await client.get_wallet_balance(wallet_id)
345-
# TODO: casting due to lack of type checked deserialization
346-
return cast(uint128, balance["confirmed_wallet_balance"])
353+
return (
354+
await client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))
355+
).wallet_balance.confirmed_wallet_balance
347356

348357

349358
async def get_unconfirmed_balance(client: WalletRpcClient, wallet_id: int) -> uint128:
350-
balance = await client.get_wallet_balance(wallet_id)
351-
# TODO: casting due to lack of type checked deserialization
352-
return cast(uint128, balance["unconfirmed_wallet_balance"])
359+
return (
360+
await client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))
361+
).wallet_balance.unconfirmed_wallet_balance
353362

354363

355364
@pytest.mark.anyio
@@ -1131,13 +1140,13 @@ async def test_cat_endpoints(wallet_environments: WalletTestFramework, wallet_ty
11311140
"cat1",
11321141
)
11331142

1134-
cat_0_id = env_0.wallet_aliases["cat0"]
1143+
cat_0_id = uint32(env_0.wallet_aliases["cat0"])
11351144
# The RPC response contains more than just the balance info but all the
11361145
# balance info should match. We're leveraging the `<=` operator to check
11371146
# for subset on `dict` `.items()`.
11381147
assert (
11391148
env_0.wallet_states[uint32(env_0.wallet_aliases["cat0"])].balance.to_json_dict().items()
1140-
<= (await env_0.rpc_client.get_wallet_balance(cat_0_id)).items()
1149+
<= (await env_0.rpc_client.get_wallet_balance(GetWalletBalance(cat_0_id))).wallet_balance.to_json_dict().items()
11411150
)
11421151
asset_id = await env_0.rpc_client.get_cat_asset_id(cat_0_id)
11431152
assert (await env_0.rpc_client.get_cat_name(cat_0_id)) == wallet_type.default_wallet_name_for_unknown_cat(
@@ -3009,15 +3018,15 @@ async def test_get_balances(wallet_rpc_environment: WalletRpcTestEnvironment) ->
30093018
await time_out_assert(5, check_mempool_spend_count, True, full_node_api, 2)
30103019
await farm_transaction_block(full_node_api, wallet_node)
30113020
await time_out_assert(20, check_client_synced, True, client)
3012-
bal = await client.get_wallet_balances()
3013-
assert len(bal) == 3
3014-
assert bal["1"]["confirmed_wallet_balance"] == 1999999999880
3015-
assert bal["2"]["confirmed_wallet_balance"] == 100
3016-
assert bal["3"]["confirmed_wallet_balance"] == 20
3017-
bal_ids = await client.get_wallet_balances([3, 2])
3018-
assert len(bal_ids) == 2
3019-
assert bal["2"]["confirmed_wallet_balance"] == 100
3020-
assert bal["3"]["confirmed_wallet_balance"] == 20
3021+
bals_response = await client.get_wallet_balances(GetWalletBalances())
3022+
assert len(bals_response.wallet_balances) == 3
3023+
assert bals_response.wallet_balances[uint32(1)].confirmed_wallet_balance == 1999999999880
3024+
assert bals_response.wallet_balances[uint32(2)].confirmed_wallet_balance == 100
3025+
assert bals_response.wallet_balances[uint32(3)].confirmed_wallet_balance == 20
3026+
bals_response = await client.get_wallet_balances(GetWalletBalances([uint32(3), uint32(2)]))
3027+
assert len(bals_response.wallet_balances) == 2
3028+
assert bals_response.wallet_balances[uint32(2)].confirmed_wallet_balance == 100
3029+
assert bals_response.wallet_balances[uint32(3)].confirmed_wallet_balance == 20
30213030

30223031

30233032
@pytest.mark.parametrize(

chia/cmds/plotnft_funcs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
4444
from chia.wallet.util.wallet_types import WalletType
4545
from chia.wallet.wallet_request_types import (
46+
GetWalletBalance,
4647
GetWallets,
4748
PWAbsorbRewards,
4849
PWJoinPool,
@@ -161,8 +162,8 @@ async def pprint_pool_wallet_state(
161162
print(f"Target state: {PoolSingletonState(pool_wallet_info.target.state).name}")
162163
print(f"Target pool URL: {pool_wallet_info.target.pool_url}")
163164
if pool_wallet_info.current.state == PoolSingletonState.SELF_POOLING.value:
164-
balances: dict[str, Any] = await wallet_client.get_wallet_balance(wallet_id)
165-
balance = balances["confirmed_wallet_balance"]
165+
balances = (await wallet_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))).wallet_balance
166+
balance = balances.confirmed_wallet_balance
166167
typ = WalletType(int(WalletType.POOLING_WALLET))
167168
address_prefix, scale = wallet_coin_unit(typ, address_prefix)
168169
print(f"Claimable balance: {print_balance(balance, scale, address_prefix)}")

0 commit comments

Comments
 (0)