Skip to content

Commit d1d1386

Browse files
committed
Port get_coin_records_by_names
1 parent 7677d00 commit d1d1386

File tree

5 files changed

+64
-51
lines changed

5 files changed

+64
-51
lines changed

chia/_tests/wallet/rpc/test_wallet_rpc.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
DIDTransferDID,
122122
DIDUpdateMetadata,
123123
FungibleAsset,
124+
GetCoinRecordsByNames,
124125
GetNextAddress,
125126
GetNotifications,
126127
GetPrivateKey,
@@ -1504,13 +1505,17 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
15041505
]
15051506
)
15061507

1507-
test_crs: list[CoinRecord] = await env_1.rpc_client.get_coin_records_by_names(
1508-
[a.name() for a in spend_bundle.additions() if a.amount != 4]
1509-
)
1508+
test_crs: list[CoinRecord] = (
1509+
await env_1.rpc_client.get_coin_records_by_names(
1510+
GetCoinRecordsByNames([a.name() for a in spend_bundle.additions() if a.amount != 4])
1511+
)
1512+
).coin_records
15101513
for cr in test_crs:
15111514
assert cr.coin in spend_bundle.additions()
15121515
with pytest.raises(ValueError):
1513-
await env_1.rpc_client.get_coin_records_by_names([a.name() for a in spend_bundle.additions() if a.amount == 4])
1516+
await env_1.rpc_client.get_coin_records_by_names(
1517+
GetCoinRecordsByNames([a.name() for a in spend_bundle.additions() if a.amount == 4])
1518+
)
15141519
# Create an offer of 5 chia for one CAT
15151520
await env_1.rpc_client.create_offer_for_ids(
15161521
{uint32(1): -5, cat_asset_id.hex(): 1}, wallet_environments.tx_config, validate_only=True
@@ -1855,16 +1860,18 @@ async def test_get_coin_records_by_names(wallet_rpc_environment: WalletRpcTestEn
18551860
assert len(coin_ids_unspent) > 0
18561861
# Do some queries to trigger all parameters
18571862
# 1. Empty coin_ids
1858-
assert await client.get_coin_records_by_names([]) == []
1863+
assert (await client.get_coin_records_by_names(GetCoinRecordsByNames([]))).coin_records == []
18591864
# 2. All coins
1860-
rpc_result = await client.get_coin_records_by_names(coin_ids + coin_ids_unspent)
1861-
assert {record.coin for record in rpc_result} == {*coins, *coins_unspent}
1865+
rpc_result = await client.get_coin_records_by_names(GetCoinRecordsByNames(coin_ids + coin_ids_unspent))
1866+
assert {record.coin for record in rpc_result.coin_records} == {*coins, *coins_unspent}
18621867
# 3. All spent coins
1863-
rpc_result = await client.get_coin_records_by_names(coin_ids, include_spent_coins=True)
1864-
assert {record.coin for record in rpc_result} == coins
1868+
rpc_result = await client.get_coin_records_by_names(GetCoinRecordsByNames(coin_ids, include_spent_coins=True))
1869+
assert {record.coin for record in rpc_result.coin_records} == coins
18651870
# 4. All unspent coins
1866-
rpc_result = await client.get_coin_records_by_names(coin_ids_unspent, include_spent_coins=False)
1867-
assert {record.coin for record in rpc_result} == coins_unspent
1871+
rpc_result = await client.get_coin_records_by_names(
1872+
GetCoinRecordsByNames(coin_ids_unspent, include_spent_coins=False)
1873+
)
1874+
assert {record.coin for record in rpc_result.coin_records} == coins_unspent
18681875
# 5. Filter start/end height
18691876
filter_records = result.records[:10]
18701877
assert len(filter_records) == 10
@@ -1873,11 +1880,13 @@ async def test_get_coin_records_by_names(wallet_rpc_environment: WalletRpcTestEn
18731880
min_height = min(record.confirmed_block_height for record in filter_records)
18741881
max_height = max(record.confirmed_block_height for record in filter_records)
18751882
assert min_height != max_height
1876-
rpc_result = await client.get_coin_records_by_names(filter_coin_ids, start_height=min_height, end_height=max_height)
1877-
assert {record.coin for record in rpc_result} == filter_coins
1883+
rpc_result = await client.get_coin_records_by_names(
1884+
GetCoinRecordsByNames(filter_coin_ids, start_height=min_height, end_height=max_height)
1885+
)
1886+
assert {record.coin for record in rpc_result.coin_records} == filter_coins
18781887
# 8. Test the failure case
18791888
with pytest.raises(ValueError, match="not found"):
1880-
await client.get_coin_records_by_names(coin_ids, include_spent_coins=False)
1889+
await client.get_coin_records_by_names(GetCoinRecordsByNames(coin_ids, include_spent_coins=False))
18811890

18821891

18831892
@pytest.mark.anyio

chia/cmds/coin_funcs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from chia.wallet.conditions import ConditionValidTimes
1919
from chia.wallet.transaction_record import TransactionRecord
2020
from chia.wallet.util.wallet_types import WalletType
21-
from chia.wallet.wallet_request_types import CombineCoins, GetSpendableCoins, SplitCoins
21+
from chia.wallet.wallet_request_types import CombineCoins, GetCoinRecordsByNames, GetSpendableCoins, SplitCoins
2222

2323

2424
async def async_list(
@@ -222,19 +222,19 @@ async def async_split(
222222
return []
223223

224224
if number_of_coins is None:
225-
coins = await client_info.client.get_coin_records_by_names([target_coin_id])
226-
if len(coins) == 0:
225+
response = await client_info.client.get_coin_records_by_names(GetCoinRecordsByNames([target_coin_id]))
226+
if len(response.coin_records) == 0:
227227
print("Could not find target coin.")
228228
return []
229229
assert amount_per_coin is not None
230-
number_of_coins = int(coins[0].coin.amount // amount_per_coin.convert_amount(mojo_per_unit))
230+
number_of_coins = int(response.coin_records[0].coin.amount // amount_per_coin.convert_amount(mojo_per_unit))
231231
elif amount_per_coin is None:
232-
coins = await client_info.client.get_coin_records_by_names([target_coin_id])
233-
if len(coins) == 0:
232+
response = await client_info.client.get_coin_records_by_names(GetCoinRecordsByNames([target_coin_id]))
233+
if len(response.coin_records) == 0:
234234
print("Could not find target coin.")
235235
return []
236236
assert number_of_coins is not None
237-
amount_per_coin = CliAmount(True, uint64(coins[0].coin.amount // number_of_coins))
237+
amount_per_coin = CliAmount(True, uint64(response.coin_records[0].coin.amount // number_of_coins))
238238

239239
final_amount_per_coin = amount_per_coin.convert_amount(mojo_per_unit)
240240

chia/wallet/wallet_request_types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,21 @@ class GetSpendableCoinsResponse(Streamable):
539539
unconfirmed_additions: list[Coin]
540540

541541

542+
@streamable
543+
@dataclass(frozen=True)
544+
class GetCoinRecordsByNames(Streamable):
545+
names: list[bytes32]
546+
start_height: Optional[uint32] = None
547+
end_height: Optional[uint32] = None
548+
include_spent_coins: bool = False
549+
550+
551+
@streamable
552+
@dataclass(frozen=True)
553+
class GetCoinRecordsByNamesResponse(Streamable):
554+
coin_records: list[CoinRecord]
555+
556+
542557
@streamable
543558
@dataclass(frozen=True)
544559
class GetCurrentDerivationIndexResponse(Streamable):

chia/wallet/wallet_rpc_api.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from chia.types.signing_mode import CHIP_0002_SIGN_MESSAGE_PREFIX, SigningMode
2828
from chia.util.bech32m import decode_puzzle_hash, encode_puzzle_hash
2929
from chia.util.byte_types import hexstr_to_bytes
30-
from chia.util.config import load_config, str2bool
30+
from chia.util.config import load_config
3131
from chia.util.errors import KeychainIsLocked
3232
from chia.util.hash import std_hash
3333
from chia.util.keychain import bytes_to_mnemonic, generate_mnemonic
@@ -172,6 +172,8 @@
172172
GatherSigningInfo,
173173
GatherSigningInfoResponse,
174174
GenerateMnemonicResponse,
175+
GetCoinRecordsByNames,
176+
GetCoinRecordsByNamesResponse,
175177
GetCurrentDerivationIndexResponse,
176178
GetHeightInfoResponse,
177179
GetLoggedInFingerprintResponse,
@@ -1826,39 +1828,37 @@ async def get_spendable_coins(self, request: GetSpendableCoins) -> GetSpendableC
18261828
unconfirmed_additions=unconfirmed_additions,
18271829
)
18281830

1829-
async def get_coin_records_by_names(self, request: dict[str, Any]) -> EndpointResult:
1831+
@marshal
1832+
async def get_coin_records_by_names(self, request: GetCoinRecordsByNames) -> GetCoinRecordsByNamesResponse:
18301833
if await self.service.wallet_state_manager.synced() is False:
18311834
raise ValueError("Wallet needs to be fully synced before finding coin information")
18321835

1833-
if "names" not in request:
1834-
raise ValueError("Names not in request")
1835-
coin_ids = [bytes32.from_hexstr(name) for name in request["names"]]
18361836
kwargs: dict[str, Any] = {
1837-
"coin_id_filter": HashFilter.include(coin_ids),
1837+
"coin_id_filter": HashFilter.include(request.names),
18381838
}
18391839

18401840
confirmed_range = UInt32Range()
1841-
if "start_height" in request:
1842-
confirmed_range = dataclasses.replace(confirmed_range, start=uint32(request["start_height"]))
1843-
if "end_height" in request:
1844-
confirmed_range = dataclasses.replace(confirmed_range, stop=uint32(request["end_height"]))
1841+
if request.start_height is not None:
1842+
confirmed_range = dataclasses.replace(confirmed_range, start=request.start_height)
1843+
if request.end_height is not None:
1844+
confirmed_range = dataclasses.replace(confirmed_range, stop=request.end_height)
18451845
if confirmed_range != UInt32Range():
18461846
kwargs["confirmed_range"] = confirmed_range
18471847

1848-
if "include_spent_coins" in request and not str2bool(request["include_spent_coins"]):
1848+
if request.include_spent_coins:
18491849
kwargs["spent_range"] = unspent_range
18501850

18511851
async with self.service.wallet_state_manager.lock:
18521852
coin_records: list[CoinRecord] = await self.service.wallet_state_manager.get_coin_records_by_coin_ids(
18531853
**kwargs
18541854
)
18551855
missed_coins: list[str] = [
1856-
"0x" + c_id.hex() for c_id in coin_ids if c_id not in [cr.name for cr in coin_records]
1856+
"0x" + c_id.hex() for c_id in request.names if c_id not in [cr.name for cr in coin_records]
18571857
]
18581858
if missed_coins:
18591859
raise ValueError(f"Coin ID's: {missed_coins} not found.")
18601860

1861-
return {"coin_records": [cr.to_json_dict() for cr in coin_records]}
1861+
return GetCoinRecordsByNamesResponse(coin_records)
18621862

18631863
@marshal
18641864
async def get_current_derivation_index(self, request: Empty) -> GetCurrentDerivationIndexResponse:

chia/wallet/wallet_rpc_client.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from chia.rpc.rpc_client import RpcClient
1010
from chia.types.blockchain_format.coin import Coin
1111
from chia.types.blockchain_format.program import Program
12-
from chia.types.coin_record import CoinRecord
1312
from chia.wallet.conditions import Condition, ConditionValidTimes, conditions_to_json_dicts
1413
from chia.wallet.puzzles.clawback.metadata import AutoClaimSettings
1514
from chia.wallet.trade_record import TradeRecord
@@ -88,6 +87,8 @@
8887
GatherSigningInfoResponse,
8988
GenerateMnemonicResponse,
9089
GetCATListResponse,
90+
GetCoinRecordsByNames,
91+
GetCoinRecordsByNamesResponse,
9192
GetCurrentDerivationIndexResponse,
9293
GetHeightInfoResponse,
9394
GetLoggedInFingerprintResponse,
@@ -418,22 +419,10 @@ async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]:
418419
async def get_spendable_coins(self, request: GetSpendableCoins) -> GetSpendableCoinsResponse:
419420
return GetSpendableCoinsResponse.from_json_dict(await self.fetch("get_spendable_coins", request.to_json_dict()))
420421

421-
async def get_coin_records_by_names(
422-
self,
423-
names: list[bytes32],
424-
include_spent_coins: bool = True,
425-
start_height: Optional[int] = None,
426-
end_height: Optional[int] = None,
427-
) -> list[CoinRecord]:
428-
names_hex = [name.hex() for name in names]
429-
request = {"names": names_hex, "include_spent_coins": include_spent_coins}
430-
if start_height is not None:
431-
request["start_height"] = start_height
432-
if end_height is not None:
433-
request["end_height"] = end_height
434-
435-
response = await self.fetch("get_coin_records_by_names", request)
436-
return [CoinRecord.from_json_dict(cr) for cr in response["coin_records"]]
422+
async def get_coin_records_by_names(self, request: GetCoinRecordsByNames) -> GetCoinRecordsByNamesResponse:
423+
return GetCoinRecordsByNamesResponse.from_json_dict(
424+
await self.fetch("get_coin_records_by_names", request.to_json_dict())
425+
)
437426

438427
# DID wallet
439428
async def create_new_did_wallet(

0 commit comments

Comments
 (0)