Skip to content

Commit 018c3fe

Browse files
authored
[CHIA-3293] Port get_wallets to @marshal decorator (#19770)
* Port `get_wallets` * Suggestion by @altendky
1 parent 4ba3f00 commit 018c3fe

File tree

12 files changed

+197
-140
lines changed

12 files changed

+197
-140
lines changed

chia/_tests/cmds/cmd_test_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from contextlib import asynccontextmanager
66
from dataclasses import dataclass, field
77
from pathlib import Path
8-
from typing import Any, Optional, Union, cast
8+
from typing import Any, Optional, cast
99

1010
from chia_rs import BlockRecord, Coin, G2Element
1111
from chia_rs.sized_bytes import bytes32
@@ -35,11 +35,14 @@
3535
from chia.wallet.util.wallet_types import WalletType
3636
from chia.wallet.wallet_request_types import (
3737
GetSyncStatusResponse,
38+
GetWallets,
39+
GetWalletsResponse,
3840
NFTCalculateRoyalties,
3941
NFTCalculateRoyaltiesResponse,
4042
NFTGetInfo,
4143
NFTGetInfoResponse,
4244
SendTransactionMultiResponse,
45+
WalletInfoResponse,
4346
)
4447
from chia.wallet.wallet_rpc_client import WalletRpcClient
4548
from chia.wallet.wallet_spend_bundle import WalletSpendBundle
@@ -93,11 +96,11 @@ async def get_sync_status(self) -> GetSyncStatusResponse:
9396
self.add_to_log("get_sync_status", ())
9497
return GetSyncStatusResponse(synced=True, syncing=False)
9598

96-
async def get_wallets(self, wallet_type: Optional[WalletType] = None) -> list[dict[str, Union[str, int]]]:
97-
self.add_to_log("get_wallets", (wallet_type,))
99+
async def get_wallets(self, request: GetWallets) -> GetWalletsResponse:
100+
self.add_to_log("get_wallets", (request,))
98101
# we cant start with zero because ints cant have a leading zero
99-
if wallet_type is not None:
100-
w_type = wallet_type
102+
if request.type is not None:
103+
w_type = WalletType(request.type)
101104
elif str(self.fingerprint).startswith(str(WalletType.STANDARD_WALLET.value + 1)):
102105
w_type = WalletType.STANDARD_WALLET
103106
elif str(self.fingerprint).startswith(str(WalletType.CAT.value + 1)):
@@ -110,7 +113,7 @@ async def get_wallets(self, wallet_type: Optional[WalletType] = None) -> list[di
110113
w_type = WalletType.POOLING_WALLET
111114
else:
112115
raise ValueError(f"Invalid fingerprint: {self.fingerprint}")
113-
return [{"id": 1, "type": w_type}]
116+
return GetWalletsResponse([WalletInfoResponse(id=uint32(1), name="", type=uint8(w_type.value), data="")])
114117

115118
async def get_transaction(self, transaction_id: bytes32) -> TransactionRecord:
116119
self.add_to_log("get_transaction", (transaction_id,))

chia/_tests/cmds/wallet/test_vcs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from chia.wallet.vc_wallet.vc_drivers import VCLineageProof, VerifiedCredential
1717
from chia.wallet.vc_wallet.vc_store import VCRecord
1818
from chia.wallet.wallet_request_types import (
19+
GetWallets,
1920
VCAddProofs,
2021
VCGet,
2122
VCGetList,
@@ -395,6 +396,6 @@ async def crcat_approve_pending(
395396
test_condition_valid_times,
396397
)
397398
],
398-
"get_wallets": [(None,)],
399+
"get_wallets": [(GetWallets(type=None, include_data=True),)],
399400
}
400401
test_rpc_clients.wallet_rpc_client.check_log(expected_calls)

chia/_tests/cmds/wallet/test_wallet.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@
4646
CreateOfferForIDsResponse,
4747
FungibleAsset,
4848
GetHeightInfoResponse,
49+
GetWallets,
50+
GetWalletsResponse,
4951
NFTCalculateRoyalties,
5052
NFTGetWalletDID,
5153
NFTGetWalletDIDResponse,
5254
RoyaltyAsset,
5355
SendTransactionResponse,
5456
TakeOfferResponse,
57+
WalletInfoResponse,
5558
)
5659
from chia.wallet.wallet_spend_bundle import WalletSpendBundle
5760

@@ -91,7 +94,7 @@ def test_get_transaction(capsys: object, get_test_cli_clients: tuple[TestRpcClie
9194
run_cli_command_and_assert(capsys, root_dir, [*command_args, CAT_FINGERPRINT_ARG], cat_assert_list)
9295
# these are various things that should be in the output
9396
expected_calls: logType = {
94-
"get_wallets": [(None,), (None,), (None,)],
97+
"get_wallets": [(GetWallets(type=None, include_data=True),)] * 3,
9598
"get_cat_name": [(1,)],
9699
"get_transaction": [
97100
(bytes32.from_hexstr(bytes32_hexstr),),
@@ -192,7 +195,7 @@ async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]:
192195
# these are various things that should be in the output
193196
expected_coin_id = Coin(get_bytes32(4), get_bytes32(5), uint64(12345678)).name()
194197
expected_calls: logType = {
195-
"get_wallets": [(None,), (None,)],
198+
"get_wallets": [(GetWallets(type=None, include_data=True),)] * 2,
196199
"get_transactions": [
197200
(1, 2, 4, SortKey.RELEVANCE, True, None, None, None),
198201
(1, 2, 4, SortKey.RELEVANCE, True, None, None, None),
@@ -210,28 +213,30 @@ def test_show(capsys: object, get_test_cli_clients: tuple[TestRpcClients, Path])
210213

211214
# set RPC Client
212215
class ShowRpcClient(TestWalletRpcClient):
213-
async def get_wallets(self, wallet_type: Optional[WalletType] = None) -> list[dict[str, Union[str, int]]]:
214-
self.add_to_log("get_wallets", (wallet_type,))
215-
wallet_list: list[dict[str, Union[str, int]]] = [
216-
{"data": "", "id": 1, "name": "Chia Wallet", "type": WalletType.STANDARD_WALLET},
217-
{
218-
"data": "dc59bcd60ce5fc9c93a5d3b11875486b03efb53a53da61e453f5cf61a774686001ff02ffff01ff02ffff03ff2f"
216+
async def get_wallets(self, request: GetWallets) -> GetWalletsResponse:
217+
self.add_to_log("get_wallets", (request,))
218+
wallet_list: list[WalletInfoResponse] = [
219+
WalletInfoResponse(
220+
data="", id=uint32(1), name="Chia Wallet", type=uint8(WalletType.STANDARD_WALLET.value)
221+
),
222+
WalletInfoResponse(
223+
data="dc59bcd60ce5fc9c93a5d3b11875486b03efb53a53da61e453f5cf61a774686001ff02ffff01ff02ffff03ff2f"
219224
"ffff01ff0880ffff01ff02ffff03ffff09ff2dff0280ff80ffff01ff088080ff018080ff0180ffff04ffff01a09848f0ef"
220225
"6587565c48ee225cc837abbe406b91946c938e1739da49fc26c04286ff018080",
221-
"id": 2,
222-
"name": "test2",
223-
"type": WalletType.CAT,
224-
},
225-
{
226-
"data": '{"did_id": "0xcee228b8638c67cb66a55085be99fa3b457ae5b56915896f581990f600b2c652"}',
227-
"id": 3,
228-
"name": "NFT Wallet",
229-
"type": WalletType.NFT,
230-
},
226+
id=uint32(2),
227+
name="test2",
228+
type=uint8(WalletType.CAT.value),
229+
),
230+
WalletInfoResponse(
231+
data='{"did_id": "0xcee228b8638c67cb66a55085be99fa3b457ae5b56915896f581990f600b2c652"}',
232+
id=uint32(3),
233+
name="NFT Wallet",
234+
type=uint8(WalletType.NFT.value),
235+
),
231236
]
232-
if wallet_type is WalletType.CAT:
233-
return [wallet_list[1]]
234-
return wallet_list
237+
if request.type is not None and WalletType(request.type) is WalletType.CAT:
238+
return GetWalletsResponse([wallet_list[1]])
239+
return GetWalletsResponse(wallet_list)
235240

236241
async def get_height_info(self) -> GetHeightInfoResponse:
237242
self.add_to_log("get_height_info", ())
@@ -296,7 +301,10 @@ async def get_connections(
296301
run_cli_command_and_assert(capsys, root_dir, [*command_args, "--wallet_type", "cat"], other_assert_list)
297302
# these are various things that should be in the output
298303
expected_calls: logType = {
299-
"get_wallets": [(None,), (WalletType.CAT,)],
304+
"get_wallets": [
305+
(GetWallets(type=None, include_data=True),),
306+
(GetWallets(type=uint16(WalletType.CAT.value), include_data=True),),
307+
],
300308
"get_sync_status": [(), ()],
301309
"get_height_info": [(), ()],
302310
"get_wallet_balance": [(1,), (2,), (3,), (2,)],
@@ -427,7 +435,7 @@ async def cat_spend(
427435

428436
# these are various things that should be in the output
429437
expected_calls: logType = {
430-
"get_wallets": [(None,), (None,)],
438+
"get_wallets": [(GetWallets(type=None, include_data=True),)] * 2,
431439
"send_transaction": [
432440
(
433441
1,

chia/_tests/pools/test_pool_cmdline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from chia_rs import G1Element
1111
from chia_rs.sized_bytes import bytes32
12-
from chia_rs.sized_ints import uint32, uint64
12+
from chia_rs.sized_ints import uint16, uint32, uint64
1313

1414
# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
1515
from pytest_mock import MockerFixture
@@ -46,7 +46,7 @@
4646
from chia.wallet.util.address_type import AddressType
4747
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
4848
from chia.wallet.util.wallet_types import WalletType
49-
from chia.wallet.wallet_request_types import PWStatus
49+
from chia.wallet.wallet_request_types import GetWallets, PWStatus
5050
from chia.wallet.wallet_rpc_client import WalletRpcClient
5151
from chia.wallet.wallet_state_manager import WalletStateManager
5252

@@ -153,9 +153,9 @@ async def test_plotnft_cli_create(
153153
]
154154
)
155155

156-
summaries_response = await wallet_rpc.get_wallets(WalletType.POOLING_WALLET)
157-
assert len(summaries_response) == 1
158-
wallet_id: int = summaries_response[0]["id"]
156+
summaries_response = await wallet_rpc.get_wallets(GetWallets(type=uint16(WalletType.POOLING_WALLET)))
157+
assert len(summaries_response.wallets) == 1
158+
wallet_id: int = summaries_response.wallets[0].id
159159

160160
await verify_pool_state(wallet_rpc, wallet_id, PoolSingletonState.SELF_POOLING)
161161

0 commit comments

Comments
 (0)