Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions chia/_tests/cmds/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from chia_rs import Coin, G2Element
from chia_rs.sized_bytes import bytes32
from chia_rs.sized_ints import uint8, uint16, uint32, uint64
from chia_rs.sized_ints import uint8, uint16, uint32, uint64, uint128
from click.testing import CliRunner

from chia._tests.cmds.cmd_test_utils import TestRpcClients, TestWalletRpcClient, logType, run_cli_command_and_assert
Expand Down Expand Up @@ -41,11 +41,14 @@
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_coin_store import GetCoinRecords
from chia.wallet.wallet_request_types import (
BalanceResponse,
CancelOfferResponse,
CATSpendResponse,
CreateOfferForIDsResponse,
FungibleAsset,
GetHeightInfoResponse,
GetWalletBalance,
GetWalletBalanceResponse,
GetWallets,
GetWalletsResponse,
NFTCalculateRoyalties,
Expand Down Expand Up @@ -242,19 +245,23 @@ async def get_height_info(self) -> GetHeightInfoResponse:
self.add_to_log("get_height_info", ())
return GetHeightInfoResponse(uint32(10))

async def get_wallet_balance(self, wallet_id: int) -> dict[str, uint64]:
self.add_to_log("get_wallet_balance", (wallet_id,))
if wallet_id == 1:
amount = uint64(1000000000)
elif wallet_id == 2:
amount = uint64(2000000000)
async def get_wallet_balance(self, request: GetWalletBalance) -> GetWalletBalanceResponse:
self.add_to_log("get_wallet_balance", (request,))
if request.wallet_id == 1:
amount = uint128(1000000000)
elif request.wallet_id == 2:
amount = uint128(2000000000)
else:
amount = uint64(1)
return {
"confirmed_wallet_balance": amount,
"spendable_balance": amount,
"unconfirmed_wallet_balance": uint64(0),
}
amount = uint128(1)
return GetWalletBalanceResponse(
BalanceResponse(
wallet_id=request.wallet_id,
wallet_type=uint8(0), # Doesn't matter
confirmed_wallet_balance=amount,
spendable_balance=amount,
unconfirmed_wallet_balance=uint128(0),
)
)

async def get_nft_wallet_did(self, request: NFTGetWalletDID) -> NFTGetWalletDIDResponse:
self.add_to_log("get_nft_wallet_did", (request.wallet_id,))
Expand Down Expand Up @@ -307,7 +314,12 @@ async def get_connections(
],
"get_sync_status": [(), ()],
"get_height_info": [(), ()],
"get_wallet_balance": [(1,), (2,), (3,), (2,)],
"get_wallet_balance": [
(GetWalletBalance(wallet_id=uint32(1)),),
(GetWalletBalance(wallet_id=uint32(2)),),
(GetWalletBalance(wallet_id=uint32(3)),),
(GetWalletBalance(wallet_id=uint32(2)),),
],
"get_nft_wallet_did": [(3,)],
"get_connections": [(None,), (None,)],
}
Expand Down
5 changes: 4 additions & 1 deletion chia/_tests/environments/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from chia.wallet.wallet import Wallet
from chia.wallet.wallet_node import Balance, WalletNode
from chia.wallet.wallet_node_api import WalletNodeAPI
from chia.wallet.wallet_request_types import GetWalletBalance
from chia.wallet.wallet_rpc_api import WalletRpcApi
from chia.wallet.wallet_rpc_client import WalletRpcClient
from chia.wallet.wallet_state_manager import WalletStateManager
Expand Down Expand Up @@ -169,7 +170,9 @@ async def check_balances(self, additional_balance_info: dict[Union[int, str], di
else {}
),
}
balance_response: dict[str, int] = await self.rpc_client.get_wallet_balance(wallet_id)
balance_response: dict[str, int] = (
await self.rpc_client.get_wallet_balance(GetWalletBalance(wallet_id))
).wallet_balance.to_json_dict()

if not expected_result.items() <= balance_response.items():
for key, value in expected_result.items():
Expand Down
65 changes: 36 additions & 29 deletions chia/_tests/pools/test_pool_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,14 @@
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_node import WalletNode
from chia.wallet.wallet_request_types import GetWallets, PWAbsorbRewards, PWJoinPool, PWSelfPool, PWStatus
from chia.wallet.wallet_request_types import (
GetWalletBalance,
GetWallets,
PWAbsorbRewards,
PWJoinPool,
PWSelfPool,
PWStatus,
)
from chia.wallet.wallet_rpc_client import WalletRpcClient
from chia.wallet.wallet_state_manager import WalletStateManager

Expand Down Expand Up @@ -468,8 +475,8 @@ def mempool_empty() -> bool:
assert len(asset_id) > 0
await full_node_api.process_all_wallet_transactions(wallet=wallet)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
bal_0 = await client.get_wallet_balance(cat_0_id)
assert bal_0["confirmed_wallet_balance"] == 20
bal_0 = (await client.get_wallet_balance(GetWalletBalance(cat_0_id))).wallet_balance
assert bal_0.confirmed_wallet_balance == 20

# Test creation of many pool wallets. Use untrusted since that is the more complicated protocol, but don't
# run this code more than once, since it's slow.
Expand Down Expand Up @@ -535,8 +542,8 @@ async def test_absorb_self(
await add_blocks_in_batches(blocks[-3:], full_node_api.full_node)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)

bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 2 * 1_750_000_000_000
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 2 * 1_750_000_000_000

# Claim 2 * 1.75, and farm a new 1.75
absorb_txs = (
Expand All @@ -561,8 +568,8 @@ async def test_absorb_self(
new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
assert status.current == new_status.current
assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 1 * 1_750_000_000_000
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 1 * 1_750_000_000_000

# Claim another 1.75
absorb_txs1 = (
Expand All @@ -575,8 +582,8 @@ async def test_absorb_self(

await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0

Expand All @@ -590,8 +597,8 @@ async def test_absorb_self(
await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True)

# Balance ignores non coinbase TX
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

with pytest.raises(ValueError):
await client.pw_absorb_rewards(
Expand Down Expand Up @@ -626,8 +633,8 @@ async def test_absorb_self_multiple_coins(
pool_expected_confirmed_balance = 0

await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
main_bal = await client.get_wallet_balance(1)
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance

status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
assert status.current.state == PoolSingletonState.SELF_POOLING.value
Expand All @@ -650,10 +657,10 @@ async def test_absorb_self_multiple_coins(
pool_expected_confirmed_balance += block_count * 1_750_000_000_000
main_expected_confirmed_balance += block_count * 250_000_000_000

main_bal = await client.get_wallet_balance(1)
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == pool_expected_confirmed_balance

# Claim
absorb_txs = (
Expand All @@ -671,10 +678,10 @@ async def test_absorb_self_multiple_coins(
new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
assert status.current == new_status.current
assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id
main_bal = await client.get_wallet_balance(1)
pool_bal = await client.get_wallet_balance(2)
assert pool_bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance # 10499999999999
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
pool_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert pool_bal.confirmed_wallet_balance == pool_expected_confirmed_balance
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance # 10499999999999

@pytest.mark.anyio
async def test_absorb_pooling(
Expand Down Expand Up @@ -726,8 +733,8 @@ async def farming_to_pool() -> bool:
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
# Pooled plots don't have balance
main_expected_confirmed_balance += block_count * 250_000_000_000
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

# Claim block_count * 1.75
ret = await client.pw_absorb_rewards(
Expand All @@ -751,12 +758,12 @@ async def status_updated() -> bool:

await time_out_assert(20, status_updated)
new_status = (await client.pw_status(PWStatus(uint32(2)))).state
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0
assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0
peak = full_node_api.full_node.blockchain.get_peak()
assert peak is not None
Expand Down Expand Up @@ -798,8 +805,8 @@ async def status_updated() -> bool:
status = (await client.pw_status(PWStatus(uint32(2)))).state
assert ret.fee_transaction is None

bal2 = await client.get_wallet_balance(2)
assert bal2["confirmed_wallet_balance"] == 0
bal2 = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal2.confirmed_wallet_balance == 0

@pytest.mark.anyio
async def test_self_pooling_to_pooling(self, setup: Setup, fee: uint64, self_hostname: str) -> None:
Expand Down
55 changes: 32 additions & 23 deletions chia/_tests/wallet/rpc/test_wallet_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random
from collections.abc import AsyncIterator
from operator import attrgetter
from typing import Any, Optional, cast
from typing import Any, Optional
from unittest.mock import patch

import aiosqlite
Expand Down Expand Up @@ -123,6 +123,8 @@
GetPrivateKey,
GetSyncStatusResponse,
GetTimestampForHeight,
GetWalletBalance,
GetWalletBalances,
GetWallets,
LogIn,
NFTCalculateRoyalties,
Expand Down Expand Up @@ -191,7 +193,9 @@ async def farm_transaction(

async def generate_funds(full_node_api: FullNodeSimulator, wallet_bundle: WalletBundle, num_blocks: int = 1) -> int:
wallet_id = 1
initial_balances = await wallet_bundle.rpc_client.get_wallet_balance(wallet_id)
initial_balances = (
await wallet_bundle.rpc_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))
).wallet_balance
ph: bytes32 = decode_puzzle_hash(await wallet_bundle.rpc_client.get_next_address(wallet_id, True))
generated_funds = 0
for _ in range(num_blocks):
Expand All @@ -203,8 +207,8 @@ async def generate_funds(full_node_api: FullNodeSimulator, wallet_bundle: Wallet
# Farm a dummy block to confirm the created funds
await farm_transaction_block(full_node_api, wallet_bundle.node)

expected_confirmed = initial_balances["confirmed_wallet_balance"] + generated_funds
expected_unconfirmed = initial_balances["unconfirmed_wallet_balance"] + generated_funds
expected_confirmed = initial_balances.confirmed_wallet_balance + generated_funds
expected_unconfirmed = initial_balances.unconfirmed_wallet_balance + generated_funds
await time_out_assert(20, get_confirmed_balance, expected_confirmed, wallet_bundle.rpc_client, wallet_id)
await time_out_assert(20, get_unconfirmed_balance, expected_unconfirmed, wallet_bundle.rpc_client, wallet_id)
await time_out_assert(20, check_client_synced, True, wallet_bundle.rpc_client)
Expand Down Expand Up @@ -326,13 +330,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor
async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol[Any]) -> None:
expected_balance = await wallet_node.get_balance(wallet.id())
expected_balance_dict = expected_balance.to_json_dict()
expected_balance_dict.setdefault("pending_approval_balance", None)
expected_balance_dict["wallet_id"] = wallet.id()
expected_balance_dict["wallet_type"] = wallet.type()
expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint
if wallet.type() in {WalletType.CAT, WalletType.CRCAT}:
assert isinstance(wallet, CATWallet)
expected_balance_dict["asset_id"] = wallet.get_asset_id()
assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict
expected_balance_dict["asset_id"] = "0x" + wallet.get_asset_id()
else:
expected_balance_dict["asset_id"] = None
assert (
await rpc_client.get_wallet_balance(GetWalletBalance(wallet.id()))
).wallet_balance.to_json_dict() == expected_balance_dict


async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32) -> bool:
Expand All @@ -341,15 +350,15 @@ async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32) -> boo


async def get_confirmed_balance(client: WalletRpcClient, wallet_id: int) -> uint128:
balance = await client.get_wallet_balance(wallet_id)
# TODO: casting due to lack of type checked deserialization
return cast(uint128, balance["confirmed_wallet_balance"])
return (
await client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))
).wallet_balance.confirmed_wallet_balance


async def get_unconfirmed_balance(client: WalletRpcClient, wallet_id: int) -> uint128:
balance = await client.get_wallet_balance(wallet_id)
# TODO: casting due to lack of type checked deserialization
return cast(uint128, balance["unconfirmed_wallet_balance"])
return (
await client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))
).wallet_balance.unconfirmed_wallet_balance


@pytest.mark.anyio
Expand Down Expand Up @@ -1133,13 +1142,13 @@ async def test_cat_endpoints(wallet_environments: WalletTestFramework, wallet_ty
"cat1",
)

cat_0_id = env_0.wallet_aliases["cat0"]
cat_0_id = uint32(env_0.wallet_aliases["cat0"])
# The RPC response contains more than just the balance info but all the
# balance info should match. We're leveraging the `<=` operator to check
# for subset on `dict` `.items()`.
assert (
env_0.wallet_states[uint32(env_0.wallet_aliases["cat0"])].balance.to_json_dict().items()
<= (await env_0.rpc_client.get_wallet_balance(cat_0_id)).items()
<= (await env_0.rpc_client.get_wallet_balance(GetWalletBalance(cat_0_id))).wallet_balance.to_json_dict().items()
)
asset_id = await env_0.rpc_client.get_cat_asset_id(cat_0_id)
assert (await env_0.rpc_client.get_cat_name(cat_0_id)) == wallet_type.default_wallet_name_for_unknown_cat(
Expand Down Expand Up @@ -2807,15 +2816,15 @@ async def test_get_balances(wallet_rpc_environment: WalletRpcTestEnvironment) ->
await time_out_assert(5, check_mempool_spend_count, True, full_node_api, 2)
await farm_transaction_block(full_node_api, wallet_node)
await time_out_assert(20, check_client_synced, True, client)
bal = await client.get_wallet_balances()
assert len(bal) == 3
assert bal["1"]["confirmed_wallet_balance"] == 1999999999880
assert bal["2"]["confirmed_wallet_balance"] == 100
assert bal["3"]["confirmed_wallet_balance"] == 20
bal_ids = await client.get_wallet_balances([3, 2])
assert len(bal_ids) == 2
assert bal["2"]["confirmed_wallet_balance"] == 100
assert bal["3"]["confirmed_wallet_balance"] == 20
bals_response = await client.get_wallet_balances(GetWalletBalances())
assert len(bals_response.wallet_balances) == 3
assert bals_response.wallet_balances[uint32(1)].confirmed_wallet_balance == 1999999999880
assert bals_response.wallet_balances[uint32(2)].confirmed_wallet_balance == 100
assert bals_response.wallet_balances[uint32(3)].confirmed_wallet_balance == 20
bals_response = await client.get_wallet_balances(GetWalletBalances([uint32(3), uint32(2)]))
assert len(bals_response.wallet_balances) == 2
assert bals_response.wallet_balances[uint32(2)].confirmed_wallet_balance == 100
assert bals_response.wallet_balances[uint32(3)].confirmed_wallet_balance == 20


@pytest.mark.parametrize(
Expand Down
5 changes: 3 additions & 2 deletions chia/cmds/plotnft_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_request_types import (
GetWalletBalance,
GetWallets,
PWAbsorbRewards,
PWJoinPool,
Expand Down Expand Up @@ -161,8 +162,8 @@ async def pprint_pool_wallet_state(
print(f"Target state: {PoolSingletonState(pool_wallet_info.target.state).name}")
print(f"Target pool URL: {pool_wallet_info.target.pool_url}")
if pool_wallet_info.current.state == PoolSingletonState.SELF_POOLING.value:
balances: dict[str, Any] = await wallet_client.get_wallet_balance(wallet_id)
balance = balances["confirmed_wallet_balance"]
balances = (await wallet_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id)))).wallet_balance
balance = balances.confirmed_wallet_balance
typ = WalletType(int(WalletType.POOLING_WALLET))
address_prefix, scale = wallet_coin_unit(typ, address_prefix)
print(f"Claimable balance: {print_balance(balance, scale, address_prefix)}")
Expand Down
Loading
Loading