Skip to content
Closed

touchup #19861

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions chia/_tests/cmds/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from chia_rs import Coin, G2Element
from chia_rs.sized_bytes import bytes32
from chia_rs.sized_ints import uint8, uint16, uint32, uint64
from chia_rs.sized_ints import uint8, uint16, uint32, uint64, uint128
from click.testing import CliRunner

from chia._tests.cmds.cmd_test_utils import TestRpcClients, TestWalletRpcClient, logType, run_cli_command_and_assert
Expand Down Expand Up @@ -41,11 +41,14 @@
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_coin_store import GetCoinRecords
from chia.wallet.wallet_request_types import (
BalanceResponse,
CancelOfferResponse,
CATSpendResponse,
CreateOfferForIDsResponse,
FungibleAsset,
GetHeightInfoResponse,
GetWalletBalance,
GetWalletBalanceResponse,
GetWallets,
GetWalletsResponse,
NFTCalculateRoyalties,
Expand Down Expand Up @@ -242,19 +245,27 @@ async def get_height_info(self) -> GetHeightInfoResponse:
self.add_to_log("get_height_info", ())
return GetHeightInfoResponse(uint32(10))

async def get_wallet_balance(self, wallet_id: int) -> dict[str, uint64]:
self.add_to_log("get_wallet_balance", (wallet_id,))
if wallet_id == 1:
amount = uint64(1000000000)
elif wallet_id == 2:
amount = uint64(2000000000)
async def get_wallet_balance(self, request: GetWalletBalance) -> GetWalletBalanceResponse:
self.add_to_log("get_wallet_balance", (request,))
if request.wallet_id == 1:
amount = uint128(1000000000)
elif request.wallet_id == 2:
amount = uint128(2000000000)
else:
amount = uint64(1)
return {
"confirmed_wallet_balance": amount,
"spendable_balance": amount,
"unconfirmed_wallet_balance": uint64(0),
}
amount = uint128(1)
return GetWalletBalanceResponse(
BalanceResponse(
wallet_id=request.wallet_id,
wallet_type=uint8(0), # Doesn't matter
confirmed_wallet_balance=amount,
spendable_balance=amount,
max_send_amount=uint128(0),
unspent_coin_count=uint32(0),
unconfirmed_wallet_balance=uint128(0),
pending_change=uint64(0),
pending_coin_removal_count=uint32(0),
)
)

async def get_nft_wallet_did(self, request: NFTGetWalletDID) -> NFTGetWalletDIDResponse:
self.add_to_log("get_nft_wallet_did", (request.wallet_id,))
Expand Down Expand Up @@ -307,7 +318,12 @@ async def get_connections(
],
"get_sync_status": [(), ()],
"get_height_info": [(), ()],
"get_wallet_balance": [(1,), (2,), (3,), (2,)],
"get_wallet_balance": [
(GetWalletBalance(wallet_id=uint32(1)),),
(GetWalletBalance(wallet_id=uint32(2)),),
(GetWalletBalance(wallet_id=uint32(3)),),
(GetWalletBalance(wallet_id=uint32(2)),),
],
"get_nft_wallet_did": [(3,)],
"get_connections": [(None,), (None,)],
}
Expand Down
30 changes: 17 additions & 13 deletions chia/_tests/environments/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import operator
import unittest
from collections.abc import Iterator
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, replace
from typing import TYPE_CHECKING, Any, ClassVar, Union, cast

from chia_rs.sized_bytes import bytes32
Expand All @@ -25,6 +25,7 @@
from chia.wallet.wallet import Wallet
from chia.wallet.wallet_node import Balance, WalletNode
from chia.wallet.wallet_node_api import WalletNodeAPI
from chia.wallet.wallet_request_types import GetWalletBalance
from chia.wallet.wallet_rpc_api import WalletRpcApi
from chia.wallet.wallet_rpc_client import WalletRpcClient
from chia.wallet.wallet_state_manager import WalletStateManager
Expand Down Expand Up @@ -169,7 +170,9 @@ async def check_balances(self, additional_balance_info: dict[Union[int, str], di
else {}
),
}
balance_response: dict[str, int] = await self.rpc_client.get_wallet_balance(wallet_id)
balance_response: dict[str, int] = (
await self.rpc_client.get_wallet_balance(GetWalletBalance(wallet_id))
).wallet_balance.to_json_dict()

if not expected_result.items() <= balance_response.items():
for key, value in expected_result.items():
Expand Down Expand Up @@ -214,7 +217,7 @@ async def change_balances(self, update_dictionary: dict[Union[int, str], dict[st
for wallet_id_or_alias, kwargs in update_dictionary.items():
wallet_id: uint32 = self.dealias_wallet_id(wallet_id_or_alias)

new_values: dict[str, int] = {}
new_values: dict[str, object] = {}
existing_values: Balance = await self.node.get_balance(wallet_id)
if kwargs.get("init", False):
new_values = {k: v for k, v in kwargs.items() if k not in {"set_remainder", "init"}}
Expand Down Expand Up @@ -244,21 +247,22 @@ async def change_balances(self, update_dictionary: dict[Union[int, str], dict[st
else:
new_values[key] = getattr(self.wallet_states[wallet_id].balance, key) + change

if kwargs.get("set_remainder", False):
new_balance = existing_values
elif kwargs.get("init"):
new_balance = Balance.create_empty()
else:
new_balance = self.wallet_states[wallet_id].balance

# retaining the untyped nature of the new values (for now...)
new_balance = replace(new_balance, **new_values) # type: ignore[arg-type]

self.wallet_states = {
**self.wallet_states,
wallet_id: WalletState(
**{
**({} if kwargs.get("init", False) else asdict(self.wallet_states[wallet_id])),
"balance": Balance(
**{
**(
asdict(existing_values)
if kwargs.get("set_remainder", False)
else ({} if kwargs.get("init") else asdict(self.wallet_states[wallet_id].balance))
),
**new_values,
}
),
"balance": new_balance,
}
),
}
Expand Down
65 changes: 36 additions & 29 deletions chia/_tests/pools/test_pool_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,14 @@
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_node import WalletNode
from chia.wallet.wallet_request_types import GetWallets, PWAbsorbRewards, PWJoinPool, PWSelfPool, PWStatus
from chia.wallet.wallet_request_types import (
GetWalletBalance,
GetWallets,
PWAbsorbRewards,
PWJoinPool,
PWSelfPool,
PWStatus,
)
from chia.wallet.wallet_rpc_client import WalletRpcClient
from chia.wallet.wallet_state_manager import WalletStateManager

Expand Down Expand Up @@ -468,8 +475,8 @@ def mempool_empty() -> bool:
assert len(asset_id) > 0
await full_node_api.process_all_wallet_transactions(wallet=wallet)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
bal_0 = await client.get_wallet_balance(cat_0_id)
assert bal_0["confirmed_wallet_balance"] == 20
bal_0 = (await client.get_wallet_balance(GetWalletBalance(cat_0_id))).wallet_balance
assert bal_0.confirmed_wallet_balance == 20

# Test creation of many pool wallets. Use untrusted since that is the more complicated protocol, but don't
# run this code more than once, since it's slow.
Expand Down Expand Up @@ -535,8 +542,8 @@ async def test_absorb_self(
await add_blocks_in_batches(blocks[-3:], full_node_api.full_node)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)

bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 2 * 1_750_000_000_000
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 2 * 1_750_000_000_000

# Claim 2 * 1.75, and farm a new 1.75
absorb_txs = (
Expand All @@ -561,8 +568,8 @@ async def test_absorb_self(
new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
assert status.current == new_status.current
assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 1 * 1_750_000_000_000
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 1 * 1_750_000_000_000

# Claim another 1.75
absorb_txs1 = (
Expand All @@ -575,8 +582,8 @@ async def test_absorb_self(

await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0

Expand All @@ -590,8 +597,8 @@ async def test_absorb_self(
await full_node_api.farm_blocks_to_puzzlehash(count=2, farm_to=our_ph, guarantee_transaction_blocks=True)

# Balance ignores non coinbase TX
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

with pytest.raises(ValueError):
await client.pw_absorb_rewards(
Expand Down Expand Up @@ -626,8 +633,8 @@ async def test_absorb_self_multiple_coins(
pool_expected_confirmed_balance = 0

await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
main_bal = await client.get_wallet_balance(1)
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance

status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
assert status.current.state == PoolSingletonState.SELF_POOLING.value
Expand All @@ -650,10 +657,10 @@ async def test_absorb_self_multiple_coins(
pool_expected_confirmed_balance += block_count * 1_750_000_000_000
main_expected_confirmed_balance += block_count * 250_000_000_000

main_bal = await client.get_wallet_balance(1)
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == pool_expected_confirmed_balance

# Claim
absorb_txs = (
Expand All @@ -671,10 +678,10 @@ async def test_absorb_self_multiple_coins(
new_status: PoolWalletInfo = (await client.pw_status(PWStatus(uint32(2)))).state
assert status.current == new_status.current
assert status.tip_singleton_coin_id != new_status.tip_singleton_coin_id
main_bal = await client.get_wallet_balance(1)
pool_bal = await client.get_wallet_balance(2)
assert pool_bal["confirmed_wallet_balance"] == pool_expected_confirmed_balance
assert main_bal["confirmed_wallet_balance"] == main_expected_confirmed_balance # 10499999999999
main_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(1)))).wallet_balance
pool_bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert pool_bal.confirmed_wallet_balance == pool_expected_confirmed_balance
assert main_bal.confirmed_wallet_balance == main_expected_confirmed_balance # 10499999999999

@pytest.mark.anyio
async def test_absorb_pooling(
Expand Down Expand Up @@ -726,8 +733,8 @@ async def farming_to_pool() -> bool:
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
# Pooled plots don't have balance
main_expected_confirmed_balance += block_count * 250_000_000_000
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

# Claim block_count * 1.75
ret = await client.pw_absorb_rewards(
Expand All @@ -751,12 +758,12 @@ async def status_updated() -> bool:

await time_out_assert(20, status_updated)
new_status = (await client.pw_status(PWStatus(uint32(2)))).state
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0

await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
bal = await client.get_wallet_balance(2)
assert bal["confirmed_wallet_balance"] == 0
bal = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal.confirmed_wallet_balance == 0
assert len(await wallet_node.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(2)) == 0
peak = full_node_api.full_node.blockchain.get_peak()
assert peak is not None
Expand Down Expand Up @@ -798,8 +805,8 @@ async def status_updated() -> bool:
status = (await client.pw_status(PWStatus(uint32(2)))).state
assert ret.fee_transaction is None

bal2 = await client.get_wallet_balance(2)
assert bal2["confirmed_wallet_balance"] == 0
bal2 = (await client.get_wallet_balance(GetWalletBalance(uint32(2)))).wallet_balance
assert bal2.confirmed_wallet_balance == 0

@pytest.mark.anyio
async def test_self_pooling_to_pooling(self, setup: Setup, fee: uint64, self_hostname: str) -> None:
Expand Down
Loading
Loading