diff --git a/chia/_tests/cmds/cmd_test_utils.py b/chia/_tests/cmds/cmd_test_utils.py index b0be04d1c1b8..8666076e9304 100644 --- a/chia/_tests/cmds/cmd_test_utils.py +++ b/chia/_tests/cmds/cmd_test_utils.py @@ -35,6 +35,8 @@ from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_request_types import ( GetSyncStatusResponse, + GetTransaction, + GetTransactionResponse, GetWallets, GetWalletsResponse, NFTCalculateRoyalties, @@ -115,27 +117,30 @@ async def get_wallets(self, request: GetWallets) -> GetWalletsResponse: raise ValueError(f"Invalid fingerprint: {self.fingerprint}") return GetWalletsResponse([WalletInfoResponse(id=uint32(1), name="", type=uint8(w_type.value), data="")]) - async def get_transaction(self, transaction_id: bytes32) -> TransactionRecord: - self.add_to_log("get_transaction", (transaction_id,)) - return TransactionRecord( - confirmed_at_height=uint32(1), - created_at_time=uint64(1234), - to_puzzle_hash=bytes32([1] * 32), - to_address=encode_puzzle_hash(bytes32([1] * 32), "xch"), - amount=uint64(12345678), - fee_amount=uint64(1234567), - confirmed=False, - sent=uint32(0), - spend_bundle=WalletSpendBundle([], G2Element()), - additions=[Coin(bytes32([1] * 32), bytes32([2] * 32), uint64(12345678))], - removals=[Coin(bytes32([2] * 32), bytes32([4] * 32), uint64(12345678))], - wallet_id=uint32(1), - sent_to=[("aaaaa", uint8(1), None)], - trade_id=None, - type=uint32(TransactionType.OUTGOING_TX.value), - name=bytes32([2] * 32), - memos={bytes32([3] * 32): [bytes([4] * 32)]}, - valid_times=ConditionValidTimes(), + async def get_transaction(self, request: GetTransaction) -> GetTransactionResponse: + self.add_to_log("get_transaction", (request,)) + return GetTransactionResponse( + TransactionRecord( + confirmed_at_height=uint32(1), + created_at_time=uint64(1234), + to_puzzle_hash=bytes32([1] * 32), + to_address=encode_puzzle_hash(bytes32([1] * 32), "xch"), + amount=uint64(12345678), + fee_amount=uint64(1234567), + confirmed=False, + sent=uint32(0), + spend_bundle=WalletSpendBundle([], G2Element()), + additions=[Coin(bytes32([1] * 32), bytes32([2] * 32), uint64(12345678))], + removals=[Coin(bytes32([2] * 32), bytes32([4] * 32), uint64(12345678))], + wallet_id=uint32(1), + sent_to=[("aaaaa", uint8(1), None)], + trade_id=None, + type=uint32(TransactionType.OUTGOING_TX.value), + name=bytes32([2] * 32), + memos={bytes32([3] * 32): [bytes([4] * 32)]}, + valid_times=ConditionValidTimes(), + ), + bytes32([2] * 32), ) async def get_cat_name(self, wallet_id: int) -> str: diff --git a/chia/_tests/cmds/wallet/test_wallet.py b/chia/_tests/cmds/wallet/test_wallet.py index 565f6fbf2e8e..f0a384196c8d 100644 --- a/chia/_tests/cmds/wallet/test_wallet.py +++ b/chia/_tests/cmds/wallet/test_wallet.py @@ -35,7 +35,7 @@ from chia.wallet.trading.trade_status import TradeStatus from chia.wallet.transaction_record import TransactionRecord from chia.wallet.transaction_sorting import SortKey -from chia.wallet.util.query_filter import HashFilter, TransactionTypeFilter +from chia.wallet.util.query_filter import HashFilter from chia.wallet.util.transaction_type import TransactionType from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG, TXConfig from chia.wallet.util.wallet_types import WalletType @@ -47,6 +47,9 @@ CreateOfferForIDsResponse, FungibleAsset, GetHeightInfoResponse, + GetTransaction, + GetTransactions, + GetTransactionsResponse, GetWalletBalance, GetWalletBalanceResponse, GetWallets, @@ -57,6 +60,7 @@ RoyaltyAsset, SendTransactionResponse, TakeOfferResponse, + TransactionRecordWithMetadata, WalletInfoResponse, ) from chia.wallet.wallet_spend_bundle import WalletSpendBundle @@ -100,9 +104,9 @@ def test_get_transaction(capsys: object, get_test_cli_clients: tuple[TestRpcClie "get_wallets": [(GetWallets(type=None, include_data=True),)] * 3, "get_cat_name": [(1,)], "get_transaction": [ - (bytes32.from_hexstr(bytes32_hexstr),), - (bytes32.from_hexstr(bytes32_hexstr),), - (bytes32.from_hexstr(bytes32_hexstr),), + (GetTransaction(bytes32.from_hexstr(bytes32_hexstr)),), + (GetTransaction(bytes32.from_hexstr(bytes32_hexstr)),), + (GetTransaction(bytes32.from_hexstr(bytes32_hexstr)),), ], } test_rpc_clients.wallet_rpc_client.check_log(expected_calls) @@ -113,24 +117,13 @@ def test_get_transactions(capsys: object, get_test_cli_clients: tuple[TestRpcCli # set RPC Client class GetTransactionsWalletRpcClient(TestWalletRpcClient): - async def get_transactions( - self, - wallet_id: int, - start: int, - end: int, - sort_key: Optional[SortKey] = None, - reverse: bool = False, - to_address: Optional[str] = None, - type_filter: Optional[TransactionTypeFilter] = None, - confirmed: Optional[bool] = None, - ) -> list[TransactionRecord]: - self.add_to_log( - "get_transactions", (wallet_id, start, end, sort_key, reverse, to_address, type_filter, confirmed) - ) + async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse: + self.add_to_log("get_transactions", (request,)) l_tx_rec = [] - for i in range(start, end): - t_type = TransactionType.INCOMING_CLAWBACK_SEND if i == end - 1 else TransactionType.INCOMING_TX - tx_rec = TransactionRecord( + assert request.start is not None and request.end is not None + for i in range(request.start, request.end): + t_type = TransactionType.INCOMING_CLAWBACK_SEND if i == request.end - 1 else TransactionType.INCOMING_TX + tx_rec = TransactionRecordWithMetadata( confirmed_at_height=uint32(1 + i), created_at_time=uint64(1234 + i), to_puzzle_hash=bytes32([1 + i] * 32), @@ -152,7 +145,7 @@ async def get_transactions( ) l_tx_rec.append(tx_rec) - return l_tx_rec + return GetTransactionsResponse(l_tx_rec, request.wallet_id) async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]: self.add_to_log("get_coin_records", (request,)) @@ -201,8 +194,8 @@ async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]: expected_calls: logType = { "get_wallets": [(GetWallets(type=None, include_data=True),)] * 2, "get_transactions": [ - (1, 2, 4, SortKey.RELEVANCE, True, None, None, None), - (1, 2, 4, SortKey.RELEVANCE, True, None, None, None), + (GetTransactions(uint32(1), uint16(2), uint16(4), SortKey.RELEVANCE.name, True, None, None, None),), + (GetTransactions(uint32(1), uint16(2), uint16(4), SortKey.RELEVANCE.name, True, None, None, None),), ], "get_coin_records": [ (GetCoinRecords(coin_id_filter=HashFilter.include([expected_coin_id])),), @@ -490,7 +483,7 @@ async def cat_spend( test_condition_valid_times, ) ], - "get_transaction": [(get_bytes32(2),), (get_bytes32(2),)], + "get_transaction": [(GetTransaction(get_bytes32(2)),), (GetTransaction(get_bytes32(2)),)], } test_rpc_clients.wallet_rpc_client.check_log(expected_calls) diff --git a/chia/_tests/pools/test_pool_rpc.py b/chia/_tests/pools/test_pool_rpc.py index 2c26382dd345..750485166622 100644 --- a/chia/_tests/pools/test_pool_rpc.py +++ b/chia/_tests/pools/test_pool_rpc.py @@ -42,6 +42,7 @@ from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_node import WalletNode from chia.wallet.wallet_request_types import ( + GetTransactions, GetWalletBalance, GetWallets, PWAbsorbRewards, @@ -605,7 +606,7 @@ async def test_absorb_self( PWAbsorbRewards(wallet_id=uint32(2), fee=uint64(fee), push=True), DEFAULT_TX_CONFIG ) - tx1 = await client.get_transactions(1) + tx1 = (await client.get_transactions(GetTransactions(uint32(1)))).transactions assert (250_000_000_000 + fee) in [tx.amount for tx in tx1] @pytest.mark.anyio diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index 50eb1c082820..1f5fd3a1b168 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -123,6 +123,8 @@ GetPrivateKey, GetSyncStatusResponse, GetTimestampForHeight, + GetTransaction, + GetTransactions, GetWalletBalance, GetWalletBalances, GetWallets, @@ -345,7 +347,7 @@ async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNod async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32) -> bool: - tx = await client.get_transaction(transaction_id) + tx = (await client.get_transaction(GetTransaction(transaction_id))).transaction return tx.is_in_mempool() @@ -433,7 +435,7 @@ async def test_send_transaction(wallet_rpc_environment: WalletRpcTestEnvironment await farm_transaction(full_node_api, wallet_node, spend_bundle) # Checks that the memo can be retrieved - tx_confirmed = await client.get_transaction(transaction_id) + tx_confirmed = (await client.get_transaction(GetTransaction(transaction_id))).transaction assert tx_confirmed.confirmed assert len(tx_confirmed.memos) == 1 assert [b"this is a basic tx"] in tx_confirmed.memos.values() @@ -479,7 +481,7 @@ async def test_push_transactions(wallet_rpc_environment: WalletRpcTestEnvironmen await farm_transaction(full_node_api, wallet_node, spend_bundle) for tx in resp_client.transactions: - assert (await client.get_transaction(transaction_id=tx.name)).confirmed + assert (await client.get_transaction(GetTransaction(transaction_id=tx.name))).transaction.confirmed # Just testing NOT failure here really (parsing) await client.push_tx(PushTX(spend_bundle)) @@ -973,7 +975,7 @@ async def test_send_transaction_multi(wallet_rpc_environment: WalletRpcTestEnvir await time_out_assert(20, get_confirmed_balance, generated_funds - amount_outputs - amount_fee, client, 1) # Checks that the memo can be retrieved - tx_confirmed = await client.get_transaction(send_tx_res.name) + tx_confirmed = (await client.get_transaction(GetTransaction(send_tx_res.name))).transaction assert tx_confirmed.confirmed memos = tx_confirmed.memos assert len(memos) == len(outputs) @@ -996,18 +998,20 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment await generate_funds(full_node_api, env.wallet_1, 5) - all_transactions = await client.get_transactions(1) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions assert len(all_transactions) >= 10 # Test transaction pagination - some_transactions = await client.get_transactions(1, 0, 5) - some_transactions_2 = await client.get_transactions(1, 5, 10) + some_transactions = (await client.get_transactions(GetTransactions(uint32(1), uint16(0), uint16(5)))).transactions + some_transactions_2 = ( + await client.get_transactions(GetTransactions(uint32(1), uint16(5), uint16(10))) + ).transactions assert some_transactions == all_transactions[0:5] assert some_transactions_2 == all_transactions[5:10] # Testing sorts # Test the default sort (CONFIRMED_AT_HEIGHT) assert all_transactions == sorted(all_transactions, key=attrgetter("confirmed_at_height")) - all_transactions = await client.get_transactions(1, reverse=True) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1), reverse=True))).transactions assert all_transactions == sorted(all_transactions, key=attrgetter("confirmed_at_height"), reverse=True) # Test RELEVANCE @@ -1018,13 +1022,20 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment 1, uint64(1), encode_puzzle_hash(puzhash, "txch"), DEFAULT_TX_CONFIG ) # Create a pending tx - all_transactions = await client.get_transactions(1, sort_key=SortKey.RELEVANCE) + with pytest.raises(ValueError, match="There is no known sort foo"): + await client.get_transactions(GetTransactions(uint32(1), sort_key="foo")) + + all_transactions = ( + await client.get_transactions(GetTransactions(uint32(1), sort_key=SortKey.RELEVANCE.name)) + ).transactions sorted_transactions = sorted(all_transactions, key=attrgetter("created_at_time"), reverse=True) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed_at_height"), reverse=True) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed")) assert all_transactions == sorted_transactions - all_transactions = await client.get_transactions(1, sort_key=SortKey.RELEVANCE, reverse=True) + all_transactions = ( + await client.get_transactions(GetTransactions(uint32(1), sort_key=SortKey.RELEVANCE.name, reverse=True)) + ).transactions sorted_transactions = sorted(all_transactions, key=attrgetter("created_at_time")) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed_at_height")) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed"), reverse=True) @@ -1036,21 +1047,25 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) await client.send_transaction(1, uint64(1), encode_puzzle_hash(ph_by_addr, "txch"), DEFAULT_TX_CONFIG) await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) - tx_for_address = await client.get_transactions(1, to_address=encode_puzzle_hash(ph_by_addr, "txch")) + tx_for_address = ( + await client.get_transactions(GetTransactions(uint32(1), to_address=encode_puzzle_hash(ph_by_addr, "txch"))) + ).transactions assert len(tx_for_address) == 1 assert tx_for_address[0].to_puzzle_hash == ph_by_addr # Test type filter - all_transactions = await client.get_transactions( - 1, type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD]) - ) + all_transactions = ( + await client.get_transactions( + GetTransactions(uint32(1), type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD])) + ) + ).transactions assert len(all_transactions) == 5 assert all(transaction.type == TransactionType.COINBASE_REWARD.value for transaction in all_transactions) # Test confirmed filter - all_transactions = await client.get_transactions(1, confirmed=True) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1), confirmed=True))).transactions assert len(all_transactions) == 10 assert all(transaction.confirmed for transaction in all_transactions) - all_transactions = await client.get_transactions(1, confirmed=False) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1), confirmed=False))).transactions assert len(all_transactions) == 2 assert all(not transaction.confirmed for transaction in all_transactions) @@ -1058,9 +1073,15 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment await wallet.wallet_state_manager.tx_store.add_transaction_record( dataclasses.replace(all_transactions[0], type=uint32(TransactionType.INCOMING_CLAWBACK_SEND)) ) - all_transactions = await client.get_transactions( - 1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]), confirmed=False - ) + all_transactions = ( + await client.get_transactions( + GetTransactions( + uint32(1), + type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]), + confirmed=False, + ) + ) + ).transactions assert len(all_transactions) == 1 @@ -1073,7 +1094,7 @@ async def test_get_transaction_count(wallet_rpc_environment: WalletRpcTestEnviro await generate_funds(full_node_api, env.wallet_1) - all_transactions = await client.get_transactions(1) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions assert len(all_transactions) > 0 transaction_count = await client.get_transaction_count(1) assert transaction_count == len(all_transactions) diff --git a/chia/_tests/wallet/test_wallet.py b/chia/_tests/wallet/test_wallet.py index af8345def433..b3f6161d480f 100644 --- a/chia/_tests/wallet/test_wallet.py +++ b/chia/_tests/wallet/test_wallet.py @@ -395,7 +395,7 @@ async def test_wallet_clawback_clawback(self, wallet_environments: WalletTestFra assert len(txs["transactions"]) == 1 assert not txs["transactions"][0]["confirmed"] assert txs["transactions"][0]["metadata"]["recipient_puzzle_hash"][2:] == normal_puzhash.hex() - assert txs["transactions"][0]["metadata"]["coin_id"] == merkle_coin.name().hex() + assert txs["transactions"][0]["metadata"]["coin_id"] == "0x" + merkle_coin.name().hex() with pytest.raises(ValueError): await api_0.spend_clawback_coins({}) diff --git a/chia/_tests/wallet/vc_wallet/test_vc_wallet.py b/chia/_tests/wallet/vc_wallet/test_vc_wallet.py index 2776f4951dcb..b76cfe3ecab2 100644 --- a/chia/_tests/wallet/vc_wallet/test_vc_wallet.py +++ b/chia/_tests/wallet/vc_wallet/test_vc_wallet.py @@ -7,7 +7,7 @@ import pytest from chia_rs import G2Element from chia_rs.sized_bytes import bytes32 -from chia_rs.sized_ints import uint8, uint16, uint64 +from chia_rs.sized_ints import uint8, uint16, uint32, uint64 from typing_extensions import Literal from chia._tests.environments.wallet import WalletEnvironment, WalletStateTransition, WalletTestFramework @@ -31,6 +31,7 @@ from chia.wallet.wallet import Wallet from chia.wallet.wallet_node import WalletNode from chia.wallet.wallet_request_types import ( + GetTransactions, GetWallets, VCAddProofs, VCGet, @@ -454,13 +455,17 @@ async def test_vc_lifecycle(wallet_environments: WalletTestFramework) -> None: assert await wallet_node_1.wallet_state_manager.wallets[env_1.dealias_wallet_id("crcat")].match_hinted_coin( next(c for tx in txs for c in tx.additions if c.amount == 90), wallet_1_ph ) - pending_tx = await client_1.get_transactions( - env_1.dealias_wallet_id("crcat"), - 0, - 1, - reverse=True, - type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CRCAT_PENDING]), - ) + pending_tx = ( + await client_1.get_transactions( + GetTransactions( + uint32(env_1.dealias_wallet_id("crcat")), + uint16(0), + uint16(1), + reverse=True, + type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CRCAT_PENDING]), + ) + ) + ).transactions assert len(pending_tx) == 1 # Send the VC to wallet_1 to use for the CR-CATs diff --git a/chia/cmds/plotnft_funcs.py b/chia/cmds/plotnft_funcs.py index 09204896bed1..c8d08ffc2392 100644 --- a/chia/cmds/plotnft_funcs.py +++ b/chia/cmds/plotnft_funcs.py @@ -43,6 +43,7 @@ from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_request_types import ( + GetTransaction, GetWalletBalance, GetWallets, PWAbsorbRewards, @@ -121,7 +122,7 @@ async def create( start = time.time() while time.time() - start < 10: await asyncio.sleep(0.1) - tx = await wallet_info.client.get_transaction(tx_record.name) + tx = (await wallet_info.client.get_transaction(GetTransaction(tx_record.name))).transaction if len(tx.sent_to) > 0: print(transaction_submitted_msg(tx)) print(transaction_status_msg(wallet_info.fingerprint, tx_record.name)) @@ -286,7 +287,7 @@ async def submit_tx_with_confirmation( continue while time.time() - start < 10: await asyncio.sleep(0.1) - tx = await wallet_client.get_transaction(tx_record.name) + tx = (await wallet_client.get_transaction(GetTransaction(tx_record.name))).transaction if len(tx.sent_to) > 0: print(transaction_submitted_msg(tx)) print(transaction_status_msg(fingerprint, tx_record.name)) diff --git a/chia/cmds/wallet_funcs.py b/chia/cmds/wallet_funcs.py index f0b712a63080..4e6c9ed65268 100644 --- a/chia/cmds/wallet_funcs.py +++ b/chia/cmds/wallet_funcs.py @@ -54,6 +54,8 @@ DIDUpdateMetadata, FungibleAsset, GetNotifications, + GetTransaction, + GetTransactions, GetWalletBalance, GetWallets, NFTAddURI, @@ -182,7 +184,9 @@ async def get_transaction( async with get_wallet_client(root_path, wallet_rpc_port, fingerprint) as (wallet_client, fingerprint, config): transaction_id = bytes32.from_hexstr(tx_id) address_prefix = selected_network_address_prefix(config) - tx: TransactionRecord = await wallet_client.get_transaction(transaction_id=transaction_id) + tx: TransactionRecord = ( + await wallet_client.get_transaction(GetTransaction(transaction_id=transaction_id)) + ).transaction try: wallet_type = await get_wallet_type(wallet_id=tx.wallet_id, wallet_client=wallet_client) @@ -230,9 +234,18 @@ async def get_transactions( [TransactionType.INCOMING_CLAWBACK_RECEIVE, TransactionType.INCOMING_CLAWBACK_SEND] ) ) - txs: list[TransactionRecord] = await wallet_client.get_transactions( - wallet_id, start=offset, end=(offset + limit), sort_key=sort_key, reverse=reverse, type_filter=type_filter - ) + txs = ( + await wallet_client.get_transactions( + GetTransactions( + uint32(wallet_id), + start=uint16(offset), + end=uint16(offset + limit), + sort_key=sort_key.name, + reverse=reverse, + type_filter=type_filter, + ) + ) + ).transactions address_prefix = selected_network_address_prefix(config) if len(txs) == 0: @@ -386,7 +399,7 @@ async def send( start = time.time() while time.time() - start < 10: await asyncio.sleep(0.1) - tx = await wallet_client.get_transaction(tx_id) + tx = (await wallet_client.get_transaction(GetTransaction(tx_id))).transaction if len(tx.sent_to) > 0: print(transaction_submitted_msg(tx)) print(transaction_status_msg(fingerprint, tx_id)) diff --git a/chia/wallet/wallet_request_types.py b/chia/wallet/wallet_request_types.py index 19c667f2af72..2880cda9eb62 100644 --- a/chia/wallet/wallet_request_types.py +++ b/chia/wallet/wallet_request_types.py @@ -2,7 +2,7 @@ import sys from dataclasses import dataclass, field -from typing import Any, Optional, final +from typing import Any, BinaryIO, Optional, final from chia_rs import Coin, G1Element, G2Element, PrivateKey from chia_rs.sized_bytes import bytes32 @@ -28,7 +28,9 @@ from chia.wallet.trade_record import TradeRecord from chia.wallet.trading.offer import Offer from chia.wallet.transaction_record import TransactionRecord +from chia.wallet.transaction_sorting import SortKey from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable +from chia.wallet.util.query_filter import TransactionTypeFilter from chia.wallet.util.tx_config import TXConfig from chia.wallet.vc_wallet.vc_store import VCProofs, VCRecord from chia.wallet.wallet_info import WalletInfo @@ -257,6 +259,86 @@ class GetWalletBalancesResponse(Streamable): wallet_balances: dict[uint32, BalanceResponse] +@streamable +@dataclass(frozen=True) +class GetTransaction(Streamable): + transaction_id: bytes32 + + +@streamable +@dataclass(frozen=True) +class GetTransactionResponse(Streamable): + transaction: TransactionRecord + transaction_id: bytes32 + + +@streamable +@dataclass(frozen=True) +class GetTransactions(Streamable): + wallet_id: uint32 + start: Optional[uint16] = None + end: Optional[uint16] = None + sort_key: Optional[str] = None + reverse: bool = False + to_address: Optional[str] = None + type_filter: Optional[TransactionTypeFilter] = None + confirmed: Optional[bool] = None + + def __post_init__(self) -> None: + if self.sort_key is not None and not hasattr(SortKey, self.sort_key): + raise ValueError(f"There is no known sort {self.sort_key}") + + +# utility for GetTransactionsResponse +# this class cannot be a dataclass because if it is, streamable will assume it knows how to serialize it +# TODO: We should put some thought into deprecating this and separating the metadata more reasonably +class TransactionRecordMetadata: + content: dict[str, Any] + coin_id: bytes32 + spent: bool + + def __init__(self, content: dict[str, Any], coin_id: bytes32, spent: bool) -> None: + self.content = content + self.coin_id = coin_id + self.spent = spent + + def __bytes__(self) -> bytes: + raise NotImplementedError("Should not be serializing this object as bytes, it's only for RPC") + + @classmethod + def parse(cls, f: BinaryIO) -> TransactionRecordMetadata: + raise NotImplementedError("Should not be deserializing this object from a stream, it's only for RPC") + + def to_json_dict(self) -> dict[str, Any]: + return { + **self.content, + "coin_id": "0x" + self.coin_id.hex(), + "spent": self.spent, + } + + @classmethod + def from_json_dict(cls, json_dict: dict[str, Any]) -> TransactionRecordMetadata: + return TransactionRecordMetadata( + coin_id=bytes32.from_hexstr(json_dict["coin_id"]), + spent=json_dict["spent"], + content={k: v for k, v in json_dict.items() if k not in {"coin_id", "spent"}}, + ) + + +# utility for GetTransactionsResponse +@streamable +@dataclass(frozen=True) +class TransactionRecordWithMetadata(TransactionRecord): + metadata: Optional[TransactionRecordMetadata] = None + + +@streamable +@dataclass(frozen=True) +class GetTransactionsResponse(Streamable): + transactions: list[TransactionRecordWithMetadata] + wallet_id: uint32 + + @streamable @dataclass(frozen=True) class GetNotifications(Streamable): diff --git a/chia/wallet/wallet_rpc_api.py b/chia/wallet/wallet_rpc_api.py index 85b90a2aa0e3..6324025ad592 100644 --- a/chia/wallet/wallet_rpc_api.py +++ b/chia/wallet/wallet_rpc_api.py @@ -179,6 +179,10 @@ GetSyncStatusResponse, GetTimestampForHeight, GetTimestampForHeightResponse, + GetTransaction, + GetTransactionResponse, + GetTransactions, + GetTransactionsResponse, GetWalletBalance, GetWalletBalanceResponse, GetWalletBalances, @@ -232,6 +236,7 @@ SplitCoinsResponse, SubmitTransactions, SubmitTransactionsResponse, + TransactionRecordWithMetadata, VCAddProofs, VCGet, VCGetList, @@ -1299,16 +1304,18 @@ async def get_wallet_balances(self, request: GetWalletBalances) -> GetWalletBala {wallet_id: await self._get_wallet_balance(wallet_id) for wallet_id in wallet_ids} ) - async def get_transaction(self, request: dict[str, Any]) -> EndpointResult: - transaction_id: bytes32 = bytes32.from_hexstr(request["transaction_id"]) - tr: Optional[TransactionRecord] = await self.service.wallet_state_manager.get_transaction(transaction_id) + @marshal + async def get_transaction(self, request: GetTransaction) -> GetTransactionResponse: + tr: Optional[TransactionRecord] = await self.service.wallet_state_manager.get_transaction( + request.transaction_id + ) if tr is None: - raise ValueError(f"Transaction 0x{transaction_id.hex()} not found") + raise ValueError(f"Transaction 0x{request.transaction_id.hex()} not found") - return { - "transaction": (await self._convert_tx_puzzle_hash(tr)).to_json_dict(), - "transaction_id": tr.name, - } + return GetTransactionResponse( + await self._convert_tx_puzzle_hash(tr), + tr.name, + ) async def get_transaction_memo(self, request: dict[str, Any]) -> EndpointResult: transaction_id: bytes32 = bytes32.from_hexstr(request["transaction_id"]) @@ -1491,31 +1498,21 @@ async def combine_coins( return CombineCoinsResponse([], []) # tx_endpoint will take care to fill this out - async def get_transactions(self, request: dict[str, Any]) -> EndpointResult: - wallet_id = int(request["wallet_id"]) - - start = request.get("start", 0) - end = request.get("end", 50) - sort_key = request.get("sort_key", None) - reverse = request.get("reverse", False) - - to_address = request.get("to_address", None) + @marshal + async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse: to_puzzle_hash: Optional[bytes32] = None - if to_address is not None: - to_puzzle_hash = decode_puzzle_hash(to_address) - type_filter = None - if "type_filter" in request: - type_filter = TransactionTypeFilter.from_json_dict(request["type_filter"]) + if request.to_address is not None: + to_puzzle_hash = decode_puzzle_hash(request.to_address) transactions = await self.service.wallet_state_manager.tx_store.get_transactions_between( - wallet_id, - start, - end, - sort_key=sort_key, - reverse=reverse, + wallet_id=request.wallet_id, + start=uint16(0) if request.start is None else request.start, + end=uint16(50) if request.end is None else request.end, + sort_key=request.sort_key, + reverse=request.reverse, to_puzzle_hash=to_puzzle_hash, - type_filter=type_filter, - confirmed=request.get("confirmed", None), + type_filter=request.type_filter, + confirmed=request.confirmed, ) tx_list = [] # Format for clawback transactions @@ -1538,10 +1535,10 @@ async def get_transactions(self, request: dict[str, Any]) -> EndpointResult: continue tx["metadata"]["coin_id"] = coin.name().hex() tx["metadata"]["spent"] = record.spent - return { - "transactions": tx_list, - "wallet_id": wallet_id, - } + return GetTransactionsResponse( + transactions=[TransactionRecordWithMetadata.from_json_dict(tx) for tx in tx_list], + wallet_id=request.wallet_id, + ) async def get_transaction_count(self, request: dict[str, Any]) -> EndpointResult: wallet_id = int(request["wallet_id"]) diff --git a/chia/wallet/wallet_rpc_client.py b/chia/wallet/wallet_rpc_client.py index 537eb1ad8ab5..d884302523ed 100644 --- a/chia/wallet/wallet_rpc_client.py +++ b/chia/wallet/wallet_rpc_client.py @@ -16,7 +16,6 @@ from chia.wallet.trade_record import TradeRecord from chia.wallet.trading.offer import Offer from chia.wallet.transaction_record import TransactionRecord -from chia.wallet.transaction_sorting import SortKey from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable from chia.wallet.util.query_filter import TransactionTypeFilter from chia.wallet.util.tx_config import CoinSelectionConfig, TXConfig @@ -98,8 +97,12 @@ GetSyncStatusResponse, GetTimestampForHeight, GetTimestampForHeightResponse, + GetTransaction, GetTransactionMemo, GetTransactionMemoResponse, + GetTransactionResponse, + GetTransactions, + GetTransactionsResponse, GetWalletBalance, GetWalletBalanceResponse, GetWalletBalances, @@ -266,43 +269,11 @@ async def get_wallet_balance(self, request: GetWalletBalance) -> GetWalletBalanc async def get_wallet_balances(self, request: GetWalletBalances) -> GetWalletBalancesResponse: return GetWalletBalancesResponse.from_json_dict(await self.fetch("get_wallet_balances", request.to_json_dict())) - async def get_transaction(self, transaction_id: bytes32) -> TransactionRecord: - request = {"transaction_id": transaction_id.hex()} - response = await self.fetch("get_transaction", request) - return TransactionRecord.from_json_dict(response["transaction"]) + async def get_transaction(self, request: GetTransaction) -> GetTransactionResponse: + return GetTransactionResponse.from_json_dict(await self.fetch("get_transaction", request.to_json_dict())) - async def get_transactions( - self, - wallet_id: int, - start: Optional[int] = None, - end: Optional[int] = None, - sort_key: Optional[SortKey] = None, - reverse: bool = False, - to_address: Optional[str] = None, - type_filter: Optional[TransactionTypeFilter] = None, - confirmed: Optional[bool] = None, - ) -> list[TransactionRecord]: - request: dict[str, Any] = {"wallet_id": wallet_id} - - if start is not None: - request["start"] = start - if end is not None: - request["end"] = end - if sort_key is not None: - request["sort_key"] = sort_key.name - request["reverse"] = reverse - - if to_address is not None: - request["to_address"] = to_address - - if type_filter is not None: - request["type_filter"] = type_filter.to_json_dict() - - if confirmed is not None: - request["confirmed"] = confirmed - - res = await self.fetch("get_transactions", request) - return [TransactionRecord.from_json_dict(tx) for tx in res["transactions"]] + async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse: + return GetTransactionsResponse.from_json_dict(await self.fetch("get_transactions", request.to_json_dict())) async def get_transaction_count( self, wallet_id: int, confirmed: Optional[bool] = None, type_filter: Optional[TransactionTypeFilter] = None