Skip to content

Commit 2dfe433

Browse files
committed
Port get_transaction_count
1 parent 173bdae commit 2dfe433

File tree

4 files changed

+42
-31
lines changed

4 files changed

+42
-31
lines changed

chia/_tests/wallet/rpc/test_wallet_rpc.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
GetSyncStatusResponse,
124124
GetTimestampForHeight,
125125
GetTransaction,
126+
GetTransactionCount,
126127
GetTransactions,
127128
GetWalletBalance,
128129
GetWalletBalances,
@@ -1096,14 +1097,16 @@ async def test_get_transaction_count(wallet_rpc_environment: WalletRpcTestEnviro
10961097

10971098
all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions
10981099
assert len(all_transactions) > 0
1099-
transaction_count = await client.get_transaction_count(1)
1100-
assert transaction_count == len(all_transactions)
1101-
transaction_count = await client.get_transaction_count(1, confirmed=False)
1102-
assert transaction_count == 0
1103-
transaction_count = await client.get_transaction_count(
1104-
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND])
1105-
)
1106-
assert transaction_count == 0
1100+
transaction_count_response = await client.get_transaction_count(GetTransactionCount(uint32(1)))
1101+
assert transaction_count_response.count == len(all_transactions)
1102+
transaction_count_response = await client.get_transaction_count(GetTransactionCount(uint32(1), confirmed=False))
1103+
assert transaction_count_response.count == 0
1104+
transaction_count_response = await client.get_transaction_count(
1105+
GetTransactionCount(
1106+
uint32(1), type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND])
1107+
)
1108+
)
1109+
assert transaction_count_response.count == 0
11071110

11081111

11091112
@pytest.mark.parametrize(

chia/wallet/wallet_request_types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,21 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> GetTransactionMemoResponse
397397
)
398398

399399

400+
@streamable
401+
@dataclass(frozen=True)
402+
class GetTransactionCount(Streamable):
403+
wallet_id: uint32
404+
confirmed: Optional[bool] = None
405+
type_filter: Optional[TransactionTypeFilter] = None
406+
407+
408+
@streamable
409+
@dataclass(frozen=True)
410+
class GetTransactionCountResponse(Streamable):
411+
wallet_id: uint32
412+
count: uint16
413+
414+
400415
@streamable
401416
@dataclass(frozen=True)
402417
class GetOffersCountResponse(Streamable):

chia/wallet/wallet_rpc_api.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
from chia.wallet.util.compute_hints import compute_spend_hints_and_additions
9191
from chia.wallet.util.compute_memos import compute_memos
9292
from chia.wallet.util.curry_and_treehash import NIL_TREEHASH
93-
from chia.wallet.util.query_filter import FilterMode, HashFilter, TransactionTypeFilter
93+
from chia.wallet.util.query_filter import FilterMode, HashFilter
9494
from chia.wallet.util.transaction_type import CLAWBACK_INCOMING_TRANSACTION_TYPES, TransactionType
9595
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG, TXConfig, TXConfigLoader
9696
from chia.wallet.util.wallet_sync_utils import fetch_coin_spend_for_coin_state
@@ -180,6 +180,8 @@
180180
GetTimestampForHeight,
181181
GetTimestampForHeightResponse,
182182
GetTransaction,
183+
GetTransactionCount,
184+
GetTransactionCountResponse,
183185
GetTransactionMemo,
184186
GetTransactionMemoResponse,
185187
GetTransactionResponse,
@@ -1538,18 +1540,15 @@ async def get_transactions(self, request: GetTransactions) -> GetTransactionsRes
15381540
wallet_id=request.wallet_id,
15391541
)
15401542

1541-
async def get_transaction_count(self, request: dict[str, Any]) -> EndpointResult:
1542-
wallet_id = int(request["wallet_id"])
1543-
type_filter = None
1544-
if "type_filter" in request:
1545-
type_filter = TransactionTypeFilter.from_json_dict(request["type_filter"])
1543+
@marshal
1544+
async def get_transaction_count(self, request: GetTransactionCount) -> GetTransactionCountResponse:
15461545
count = await self.service.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1547-
wallet_id, confirmed=request.get("confirmed", None), type_filter=type_filter
1546+
request.wallet_id, confirmed=request.confirmed, type_filter=request.type_filter
1547+
)
1548+
return GetTransactionCountResponse(
1549+
request.wallet_id,
1550+
uint16(count),
15481551
)
1549-
return {
1550-
"count": count,
1551-
"wallet_id": wallet_id,
1552-
}
15531552

15541553
async def get_next_address(self, request: dict[str, Any]) -> EndpointResult:
15551554
"""

chia/wallet/wallet_rpc_client.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from chia.wallet.trading.offer import Offer
1818
from chia.wallet.transaction_record import TransactionRecord
1919
from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable
20-
from chia.wallet.util.query_filter import TransactionTypeFilter
2120
from chia.wallet.util.tx_config import CoinSelectionConfig, TXConfig
2221
from chia.wallet.wallet_coin_store import GetCoinRecords
2322
from chia.wallet.wallet_request_types import (
@@ -98,6 +97,8 @@
9897
GetTimestampForHeight,
9998
GetTimestampForHeightResponse,
10099
GetTransaction,
100+
GetTransactionCount,
101+
GetTransactionCountResponse,
101102
GetTransactionMemo,
102103
GetTransactionMemoResponse,
103104
GetTransactionResponse,
@@ -275,17 +276,10 @@ async def get_transaction(self, request: GetTransaction) -> GetTransactionRespon
275276
async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse:
276277
return GetTransactionsResponse.from_json_dict(await self.fetch("get_transactions", request.to_json_dict()))
277278

278-
async def get_transaction_count(
279-
self, wallet_id: int, confirmed: Optional[bool] = None, type_filter: Optional[TransactionTypeFilter] = None
280-
) -> int:
281-
request: dict[str, Any] = {"wallet_id": wallet_id}
282-
if type_filter is not None:
283-
request["type_filter"] = type_filter.to_json_dict()
284-
if confirmed is not None:
285-
request["confirmed"] = confirmed
286-
res = await self.fetch("get_transaction_count", request)
287-
# TODO: casting due to lack of type checked deserialization
288-
return cast(int, res["count"])
279+
async def get_transaction_count(self, request: GetTransactionCount) -> GetTransactionCountResponse:
280+
return GetTransactionCountResponse.from_json_dict(
281+
await self.fetch("get_transaction_count", request.to_json_dict())
282+
)
289283

290284
async def get_next_address(self, wallet_id: int, new_address: bool) -> str:
291285
request = {"wallet_id": wallet_id, "new_address": new_address}

0 commit comments

Comments
 (0)