diff --git a/chia/_tests/cmds/wallet/test_wallet.py b/chia/_tests/cmds/wallet/test_wallet.py index 38a7f2e35ba8..5efd45679f1c 100644 --- a/chia/_tests/cmds/wallet/test_wallet.py +++ b/chia/_tests/cmds/wallet/test_wallet.py @@ -9,7 +9,7 @@ import pytest from chia_rs import Coin, G2Element from chia_rs.sized_bytes import bytes32 -from chia_rs.sized_ints import uint8, uint16, uint32, uint64 +from chia_rs.sized_ints import uint8, uint16, uint32, uint64, uint128 from click.testing import CliRunner from chia._tests.cmds.cmd_test_utils import TestRpcClients, TestWalletRpcClient, logType, run_cli_command_and_assert @@ -41,11 +41,14 @@ from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_coin_store import GetCoinRecords from chia.wallet.wallet_request_types import ( + BalanceResponse, CancelOfferResponse, CATSpendResponse, CreateOfferForIDsResponse, FungibleAsset, GetHeightInfoResponse, + GetWalletBalance, + GetWalletBalanceResponse, GetWallets, GetWalletsResponse, NFTCalculateRoyalties, @@ -242,19 +245,23 @@ async def get_height_info(self) -> GetHeightInfoResponse: self.add_to_log("get_height_info", ()) return GetHeightInfoResponse(uint32(10)) - async def get_wallet_balance(self, wallet_id: int) -> dict[str, uint64]: - self.add_to_log("get_wallet_balance", (wallet_id,)) - if wallet_id == 1: - amount = uint64(1000000000) - elif wallet_id == 2: - amount = uint64(2000000000) + async def get_wallet_balance(self, request: GetWalletBalance) -> GetWalletBalanceResponse: + self.add_to_log("get_wallet_balance", (request,)) + if request.wallet_id == 1: + amount = uint128(1000000000) + elif request.wallet_id == 2: + amount = uint128(2000000000) else: - amount = uint64(1) - return { - "confirmed_wallet_balance": amount, - "spendable_balance": amount, - "unconfirmed_wallet_balance": uint64(0), - } + amount = uint128(1) + return GetWalletBalanceResponse( + BalanceResponse( + wallet_id=request.wallet_id, + wallet_type=uint8(0), # Doesn't matter + confirmed_wallet_balance=amount, + spendable_balance=amount, + unconfirmed_wallet_balance=uint128(0), + ) + ) async def get_nft_wallet_did(self, request: NFTGetWalletDID) -> NFTGetWalletDIDResponse: self.add_to_log("get_nft_wallet_did", (request.wallet_id,)) @@ -307,7 +314,12 @@ async def get_connections( ], "get_sync_status": [(), ()], "get_height_info": [(), ()], - "get_wallet_balance": [(1,), (2,), (3,), (2,)], + "get_wallet_balance": [ + (GetWalletBalance(wallet_id=uint32(1)),), + (GetWalletBalance(wallet_id=uint32(2)),), + (GetWalletBalance(wallet_id=uint32(3)),), + (GetWalletBalance(wallet_id=uint32(2)),), + ], "get_nft_wallet_did": [(3,)], "get_connections": [(None,), (None,)], } diff --git a/chia/_tests/environments/wallet.py b/chia/_tests/environments/wallet.py index b964a03a1ef6..df38388c2102 100644 --- a/chia/_tests/environments/wallet.py +++ b/chia/_tests/environments/wallet.py @@ -25,6 +25,7 @@ from chia.wallet.wallet import Wallet from chia.wallet.wallet_node import Balance, WalletNode from chia.wallet.wallet_node_api import WalletNodeAPI +from chia.wallet.wallet_request_types import GetWalletBalance from chia.wallet.wallet_rpc_api import WalletRpcApi from chia.wallet.wallet_rpc_client import WalletRpcClient from chia.wallet.wallet_state_manager import WalletStateManager @@ -169,7 +170,9 @@ async def check_balances(self, additional_balance_info: dict[Union[int, str], di else {} ), } - balance_response: dict[str, int] = await self.rpc_client.get_wallet_balance(wallet_id) + balance_response: dict[str, int] = ( + await self.rpc_client.get_wallet_balance(GetWalletBalance(wallet_id)) + ).wallet_balance.to_json_dict() if not expected_result.items() <= balance_response.items(): for key, value in expected_result.items(): diff --git a/chia/_tests/pools/test_pool_rpc.py b/chia/_tests/pools/test_pool_rpc.py index 8e7d2460fe41..2c26382dd345 100644 --- a/chia/_tests/pools/test_pool_rpc.py +++ b/chia/_tests/pools/test_pool_rpc.py @@ -41,7 +41,14 @@ from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_node import WalletNode -from chia.wallet.wallet_request_types import GetWallets, PWAbsorbRewards, PWJoinPool, PWSelfPool, PWStatus +from chia.wallet.wallet_request_types import ( + GetWalletBalance, + GetWallets, + PWAbsorbRewards, + PWJoinPool, + PWSelfPool, + PWStatus, +) from chia.wallet.wallet_rpc_client import WalletRpcClient from chia.wallet.wallet_state_manager import WalletStateManager @@ -468,8 +475,8 @@ def mempool_empty() -> bool: assert len(asset_id) > 0 await full_node_api.process_all_wallet_transactions(wallet=wallet) await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) - bal_0 = await client.get_wallet_balance(cat_0_id) - assert bal_0["confirmed_wallet_balance"] == 20 + bal_0 = (await client.get_wallet_balance(GetWalletBalance(cat_0_id))).wallet_balance + assert bal_0.confirmed_wallet_balance == 20 # Test creation of many pool wallets. Use untrusted since that is the more complicated protocol, but don't # run this code more than once, since it's slow. @@ -535,8 +542,8 @@ async def test_absorb_self( await add_blocks_in_batches(blocks[-3:], full_node_api.full_node) await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == 2 * 1_750_000_000_000 + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == 2 * 1_750_000_000_000 # Claim 2 * 1.75, and farm a new 1.75 absorb_txs = ( @@ -561,8 +568,8 @@ async def test_absorb_self( new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state assert status.current == new_status.current assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == 1 * 1_750_000_000_000 + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == 1 * 1_750_000_000_000 # Claim another 1.75 absorb_txs1 = ( @@ -575,8 +582,8 @@ async def test_absorb_self( await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True) await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == 0 + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == 0 assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0 @@ -590,8 +597,8 @@ async def test_absorb_self( await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True) # Balance ignores non coinbase TX - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == 0 + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == 0 with pytest.raises(ValueError): await client.pw_absorb_rewards( @@ -626,8 +633,8 @@ async def test_absorb_self_multiple_coins( pool_expected_confirmed_balance = 0 await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) - main_bal = await client.get_wallet_balance(1) - assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance + main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance + assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state assert status.current.state == PoolSingletonState.SELF_POOLING.value @@ -650,10 +657,10 @@ async def test_absorb_self_multiple_coins( pool_expected_confirmed_balance += block_count * 1_750_000_000_000 main_expected_confirmed_balance += block_count * 250_000_000_000 - main_bal = await client.get_wallet_balance(1) - assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance + main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance + assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == pool_expected_confirmed_balance # Claim absorb_txs = ( @@ -671,10 +678,10 @@ async def test_absorb_self_multiple_coins( new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state assert status.current == new_status.current assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id - main_bal = await client.get_wallet_balance(1) - pool_bal = await client.get_wallet_balance(2) - assert pool_bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance - assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance # 10499999999999 + main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance + pool_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert pool_bal.confirmed_wallet_balance == pool_expected_confirmed_balance + assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance # 10499999999999 @pytest.mark.anyio async def test_absorb_pooling( @@ -726,8 +733,8 @@ async def farming_to_pool() -> bool: await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) # Pooled plots don't have balance main_expected_confirmed_balance += block_count * 250_000_000_000 - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == 0 + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == 0 # Claim block_count * 1.75 ret = await client.pw_absorb_rewards( @@ -751,12 +758,12 @@ async def status_updated() -> bool: await time_out_assert(20, status_updated) new_status = (await client.pw_status(PWStatus(uint32(2)))).state - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == 0 + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == 0 await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) - bal = await client.get_wallet_balance(2) - assert bal["confirmed_wallet_balance"] == 0 + bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal.confirmed_wallet_balance == 0 assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0 peak = full_node_api.full_node.blockchain.get_peak() assert peak is not None @@ -798,8 +805,8 @@ async def status_updated() -> bool: status = (await client.pw_status(PWStatus(uint32(2)))).state assert ret.fee_transaction is None - bal2 = await client.get_wallet_balance(2) - assert bal2["confirmed_wallet_balance"] == 0 + bal2 = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance + assert bal2.confirmed_wallet_balance == 0 @pytest.mark.anyio async def test_self_pooling_to_pooling(self, setup: Setup, fee: uint64, self_hostname: str) -> None: diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index c6ff74abd7c8..ec60209c040b 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -8,7 +8,7 @@ import random from collections.abc import AsyncIterator from operator import attrgetter -from typing import Any, Optional, cast +from typing import Any, Optional from unittest.mock import patch import aiosqlite @@ -123,6 +123,8 @@ GetPrivateKey, GetSyncStatusResponse, GetTimestampForHeight, + GetWalletBalance, + GetWalletBalances, GetWallets, LogIn, NFTCalculateRoyalties, @@ -191,7 +193,9 @@ async def farm_transaction( async def generate_funds(full_node_api: FullNodeSimulator, wallet_bundle: WalletBundle, num_blocks: int = 1) -> int: wallet_id = 1 - initial_balances = await wallet_bundle.rpc_client.get_wallet_balance(wallet_id) + initial_balances = ( + await wallet_bundle.rpc_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id))) + ).wallet_balance ph: bytes32 = decode_puzzle_hash(await wallet_bundle.rpc_client.get_next_address(wallet_id, True)) generated_funds = 0 for _ in range(num_blocks): @@ -203,8 +207,8 @@ async def generate_funds(full_node_api: FullNodeSimulator, wallet_bundle: Wallet # Farm a dummy block to confirm the created funds await farm_transaction_block(full_node_api, wallet_bundle.node) - expected_confirmed = initial_balances["confirmed_wallet_balance"] + generated_funds - expected_unconfirmed = initial_balances["unconfirmed_wallet_balance"] + generated_funds + expected_confirmed = initial_balances.confirmed_wallet_balance + generated_funds + expected_unconfirmed = initial_balances.unconfirmed_wallet_balance + generated_funds await time_out_assert(20, get_confirmed_balance, expected_confirmed, wallet_bundle.rpc_client, wallet_id) await time_out_assert(20, get_unconfirmed_balance, expected_unconfirmed, wallet_bundle.rpc_client, wallet_id) await time_out_assert(20, check_client_synced, True, wallet_bundle.rpc_client) @@ -326,13 +330,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol[Any]) -> None: expected_balance = await wallet_node.get_balance(wallet.id()) expected_balance_dict = expected_balance.to_json_dict() + expected_balance_dict.setdefault("pending_approval_balance", None) expected_balance_dict["wallet_id"] = wallet.id() expected_balance_dict["wallet_type"] = wallet.type() expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint if wallet.type() in {WalletType.CAT, WalletType.CRCAT}: assert isinstance(wallet, CATWallet) - expected_balance_dict["asset_id"] = wallet.get_asset_id() - assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict + expected_balance_dict["asset_id"] = "0x" + wallet.get_asset_id() + else: + expected_balance_dict["asset_id"] = None + assert ( + await rpc_client.get_wallet_balance(GetWalletBalance(wallet.id())) + ).wallet_balance.to_json_dict() == expected_balance_dict async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32) -> bool: @@ -341,15 +350,15 @@ async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32) -> boo async def get_confirmed_balance(client: WalletRpcClient, wallet_id: int) -> uint128: - balance = await client.get_wallet_balance(wallet_id) - # TODO: casting due to lack of type checked deserialization - return cast(uint128, balance["confirmed_wallet_balance"]) + return ( + await client.get_wallet_balance(GetWalletBalance(uint32(wallet_id))) + ).wallet_balance.confirmed_wallet_balance async def get_unconfirmed_balance(client: WalletRpcClient, wallet_id: int) -> uint128: - balance = await client.get_wallet_balance(wallet_id) - # TODO: casting due to lack of type checked deserialization - return cast(uint128, balance["unconfirmed_wallet_balance"]) + return ( + await client.get_wallet_balance(GetWalletBalance(uint32(wallet_id))) + ).wallet_balance.unconfirmed_wallet_balance @pytest.mark.anyio @@ -1133,13 +1142,13 @@ async def test_cat_endpoints(wallet_environments: WalletTestFramework, wallet_ty "cat1", ) - cat_0_id = env_0.wallet_aliases["cat0"] + cat_0_id = uint32(env_0.wallet_aliases["cat0"]) # The RPC response contains more than just the balance info but all the # balance info should match. We're leveraging the `<=` operator to check # for subset on `dict` `.items()`. assert ( env_0.wallet_states[uint32(env_0.wallet_aliases["cat0"])].balance.to_json_dict().items() - <= (await env_0.rpc_client.get_wallet_balance(cat_0_id)).items() + <= (await env_0.rpc_client.get_wallet_balance(GetWalletBalance(cat_0_id))).wallet_balance.to_json_dict().items() ) asset_id = await env_0.rpc_client.get_cat_asset_id(cat_0_id) assert (await env_0.rpc_client.get_cat_name(cat_0_id)) == wallet_type.default_wallet_name_for_unknown_cat( @@ -2807,15 +2816,15 @@ async def test_get_balances(wallet_rpc_environment: WalletRpcTestEnvironment) -> await time_out_assert(5, check_mempool_spend_count, True, full_node_api, 2) await farm_transaction_block(full_node_api, wallet_node) await time_out_assert(20, check_client_synced, True, client) - bal = await client.get_wallet_balances() - assert len(bal) == 3 - assert bal["1"]["confirmed_wallet_balance"] == 1999999999880 - assert bal["2"]["confirmed_wallet_balance"] == 100 - assert bal["3"]["confirmed_wallet_balance"] == 20 - bal_ids = await client.get_wallet_balances([3, 2]) - assert len(bal_ids) == 2 - assert bal["2"]["confirmed_wallet_balance"] == 100 - assert bal["3"]["confirmed_wallet_balance"] == 20 + bals_response = await client.get_wallet_balances(GetWalletBalances()) + assert len(bals_response.wallet_balances) == 3 + assert bals_response.wallet_balances[uint32(1)].confirmed_wallet_balance == 1999999999880 + assert bals_response.wallet_balances[uint32(2)].confirmed_wallet_balance == 100 + assert bals_response.wallet_balances[uint32(3)].confirmed_wallet_balance == 20 + bals_response = await client.get_wallet_balances(GetWalletBalances([uint32(3), uint32(2)])) + assert len(bals_response.wallet_balances) == 2 + assert bals_response.wallet_balances[uint32(2)].confirmed_wallet_balance == 100 + assert bals_response.wallet_balances[uint32(3)].confirmed_wallet_balance == 20 @pytest.mark.parametrize( diff --git a/chia/cmds/plotnft_funcs.py b/chia/cmds/plotnft_funcs.py index 329b59a6fdac..09204896bed1 100644 --- a/chia/cmds/plotnft_funcs.py +++ b/chia/cmds/plotnft_funcs.py @@ -43,6 +43,7 @@ from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_request_types import ( + GetWalletBalance, GetWallets, PWAbsorbRewards, PWJoinPool, @@ -161,8 +162,8 @@ async def pprint_pool_wallet_state( print(f"Target state: {PoolSingletonState(pool_wallet_info.target.state).name}") print(f"Target pool URL: {pool_wallet_info.target.pool_url}") if pool_wallet_info.current.state == PoolSingletonState.SELF_POOLING.value: - balances: dict[str, Any] = await wallet_client.get_wallet_balance(wallet_id) - balance = balances["confirmed_wallet_balance"] + balances = (await wallet_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))).wallet_balance + balance = balances.confirmed_wallet_balance typ = WalletType(int(WalletType.POOLING_WALLET)) address_prefix, scale = wallet_coin_unit(typ, address_prefix) print(f"Claimable balance: {print_balance(balance, scale, address_prefix)}") diff --git a/chia/cmds/wallet_funcs.py b/chia/cmds/wallet_funcs.py index 9cc877999db8..21e0749e8155 100644 --- a/chia/cmds/wallet_funcs.py +++ b/chia/cmds/wallet_funcs.py @@ -54,6 +54,7 @@ DIDUpdateMetadata, FungibleAsset, GetNotifications, + GetWalletBalance, GetWallets, NFTAddURI, NFTCalculateRoyalties, @@ -939,14 +940,14 @@ async def print_balances( # A future RPC update may split them apart, but for now we'll show the first 32 bytes (64 chars) asset_id = summary.data[:64] wallet_id = summary.id - balances = await wallet_client.get_wallet_balance(wallet_id) + balances = (await wallet_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))).wallet_balance typ = WalletType(int(summary.type)) address_prefix, scale = wallet_coin_unit(typ, address_prefix) - total_balance: str = print_balance(balances["confirmed_wallet_balance"], scale, address_prefix) + total_balance: str = print_balance(balances.confirmed_wallet_balance, scale, address_prefix) unconfirmed_wallet_balance: str = print_balance( - balances["unconfirmed_wallet_balance"], scale, address_prefix + balances.unconfirmed_wallet_balance, scale, address_prefix ) - spendable_balance: str = print_balance(balances["spendable_balance"], scale, address_prefix) + spendable_balance: str = print_balance(balances.spendable_balance, scale, address_prefix) my_did: Optional[str] = None ljust = 23 if typ == WalletType.CRCAT: @@ -955,9 +956,10 @@ async def print_balances( print(f"{summary.name}:") print(f"{indent}{'-Total Balance:'.ljust(ljust)} {total_balance}") if typ == WalletType.CRCAT: + assert balances.pending_approval_balance is not None print( f"{indent}{'-Balance Pending VC Approval:'.ljust(ljust)} " - f"{print_balance(balances['pending_approval_balance'], scale, address_prefix)}" + f"{print_balance(balances.pending_approval_balance, scale, address_prefix)}" ) print(f"{indent}{'-Pending Total Balance:'.ljust(ljust)} {unconfirmed_wallet_balance}") print(f"{indent}{'-Spendable:'.ljust(ljust)} {spendable_balance}") diff --git a/chia/wallet/wallet_request_types.py b/chia/wallet/wallet_request_types.py index 7a3d18e4b97b..41ead6d8f618 100644 --- a/chia/wallet/wallet_request_types.py +++ b/chia/wallet/wallet_request_types.py @@ -6,7 +6,7 @@ from chia_rs import Coin, G1Element, G2Element, PrivateKey from chia_rs.sized_bytes import bytes32 -from chia_rs.sized_ints import uint16, uint32, uint64 +from chia_rs.sized_ints import uint8, uint16, uint32, uint64 from typing_extensions import Self, dataclass_transform from chia.data_layer.data_layer_wallet import Mirror @@ -32,6 +32,7 @@ from chia.wallet.util.tx_config import TXConfig from chia.wallet.vc_wallet.vc_store import VCProofs, VCRecord from chia.wallet.wallet_info import WalletInfo +from chia.wallet.wallet_node import Balance from chia.wallet.wallet_spend_bundle import WalletSpendBundle @@ -221,6 +222,41 @@ class GetWalletsResponse(Streamable): fingerprint: Optional[uint32] = None +@streamable +@dataclass(frozen=True) +class GetWalletBalance(Streamable): + wallet_id: uint32 + + +@streamable +@dataclass(frozen=True) +class GetWalletBalances(Streamable): + wallet_ids: Optional[list[uint32]] = None + + +# utility for GetWalletBalanceResponse(s) +@streamable +@kw_only_dataclass +class BalanceResponse(Balance): + wallet_id: uint32 = field(default_factory=default_raise) + wallet_type: uint8 = field(default_factory=default_raise) + fingerprint: Optional[uint32] = None + asset_id: Optional[bytes32] = None + pending_approval_balance: Optional[uint64] = None + + +@streamable +@dataclass(frozen=True) +class GetWalletBalanceResponse(Streamable): + wallet_balance: BalanceResponse + + +@streamable +@dataclass(frozen=True) +class GetWalletBalancesResponse(Streamable): + wallet_balances: dict[uint32, BalanceResponse] + + @streamable @dataclass(frozen=True) class GetNotifications(Streamable): diff --git a/chia/wallet/wallet_rpc_api.py b/chia/wallet/wallet_rpc_api.py index b764922c2989..d242222e9116 100644 --- a/chia/wallet/wallet_rpc_api.py +++ b/chia/wallet/wallet_rpc_api.py @@ -111,6 +111,7 @@ AddKeyResponse, ApplySignatures, ApplySignaturesResponse, + BalanceResponse, CheckDeleteKey, CheckDeleteKeyResponse, CombineCoins, @@ -178,6 +179,10 @@ GetSyncStatusResponse, GetTimestampForHeight, GetTimestampForHeightResponse, + GetWalletBalance, + GetWalletBalanceResponse, + GetWalletBalances, + GetWalletBalancesResponse, GetWallets, GetWalletsResponse, LogIn, @@ -1265,7 +1270,7 @@ async def create_new_wallet( # Wallet ########################################################################################## - async def _get_wallet_balance(self, wallet_id: uint32) -> dict[str, Any]: + async def _get_wallet_balance(self, wallet_id: uint32) -> BalanceResponse: wallet = self.service.wallet_state_manager.wallets[wallet_id] balance = await self.service.get_balance(wallet_id) wallet_balance = balance.to_json_dict() @@ -1280,22 +1285,21 @@ async def _get_wallet_balance(self, wallet_id: uint32) -> dict[str, Any]: assert isinstance(wallet, CRCATWallet) wallet_balance["pending_approval_balance"] = await wallet.get_pending_approval_balance() - return wallet_balance + return BalanceResponse.from_json_dict(wallet_balance) - async def get_wallet_balance(self, request: dict[str, Any]) -> EndpointResult: - wallet_id = uint32(request["wallet_id"]) - wallet_balance = await self._get_wallet_balance(wallet_id) - return {"wallet_balance": wallet_balance} + @marshal + async def get_wallet_balance(self, request: GetWalletBalance) -> GetWalletBalanceResponse: + return GetWalletBalanceResponse(await self._get_wallet_balance(request.wallet_id)) - async def get_wallet_balances(self, request: dict[str, Any]) -> EndpointResult: - try: - wallet_ids: list[uint32] = [uint32(wallet_id) for wallet_id in request["wallet_ids"]] - except (TypeError, KeyError): + @marshal + async def get_wallet_balances(self, request: GetWalletBalances) -> GetWalletBalancesResponse: + if request.wallet_ids is not None: + wallet_ids = request.wallet_ids + else: wallet_ids = list(self.service.wallet_state_manager.wallets.keys()) - wallet_balances: dict[uint32, dict[str, Any]] = {} - for wallet_id in wallet_ids: - wallet_balances[wallet_id] = await self._get_wallet_balance(wallet_id) - return {"wallet_balances": wallet_balances} + return GetWalletBalancesResponse( + {wallet_id: await self._get_wallet_balance(wallet_id) for wallet_id in wallet_ids} + ) async def get_transaction(self, request: dict[str, Any]) -> EndpointResult: transaction_id: bytes32 = bytes32.from_hexstr(request["transaction_id"]) diff --git a/chia/wallet/wallet_rpc_client.py b/chia/wallet/wallet_rpc_client.py index 2744a75363f6..17397804e3c6 100644 --- a/chia/wallet/wallet_rpc_client.py +++ b/chia/wallet/wallet_rpc_client.py @@ -100,6 +100,10 @@ GetTimestampForHeightResponse, GetTransactionMemo, GetTransactionMemoResponse, + GetWalletBalance, + GetWalletBalanceResponse, + GetWalletBalances, + GetWalletBalancesResponse, GetWallets, GetWalletsResponse, LogIn, @@ -256,17 +260,11 @@ async def get_wallets(self, request: GetWallets) -> GetWalletsResponse: return GetWalletsResponse.from_json_dict(await self.fetch("get_wallets", request.to_json_dict())) # Wallet APIs - async def get_wallet_balance(self, wallet_id: int) -> dict[str, Any]: - request = {"wallet_id": wallet_id} - response = await self.fetch("get_wallet_balance", request) - # TODO: casting due to lack of type checked deserialization - return cast(dict[str, Any], response["wallet_balance"]) + async def get_wallet_balance(self, request: GetWalletBalance) -> GetWalletBalanceResponse: + return GetWalletBalanceResponse.from_json_dict(await self.fetch("get_wallet_balance", request.to_json_dict())) - async def get_wallet_balances(self, wallet_ids: Optional[list[int]] = None) -> dict[str, dict[str, Any]]: - request = {"wallet_ids": wallet_ids} - response = await self.fetch("get_wallet_balances", request) - # TODO: casting due to lack of type checked deserialization - return cast(dict[str, dict[str, Any]], response["wallet_balances"]) + async def get_wallet_balances(self, request: GetWalletBalances) -> GetWalletBalancesResponse: + return GetWalletBalancesResponse.from_json_dict(await self.fetch("get_wallet_balances", request.to_json_dict())) async def get_transaction(self, transaction_id: bytes32) -> TransactionRecord: request = {"transaction_id": transaction_id.hex()}