Skip to content

Commit 3cbea35

Browse files
committed
Port extend_derivation_index
1 parent 490c194 commit 3cbea35

File tree

5 files changed

+39
-25
lines changed

5 files changed

+39
-25
lines changed

chia/_tests/wallet/test_wallet_state_manager.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from chia.wallet.transaction_record import TransactionRecord
2222
from chia.wallet.util.transaction_type import TransactionType
2323
from chia.wallet.util.wallet_types import WalletType
24-
from chia.wallet.wallet_request_types import PushTransactions
24+
from chia.wallet.wallet_request_types import ExtendDerivationIndex, PushTransactions
2525
from chia.wallet.wallet_rpc_api import MAX_DERIVATION_INDEX_DELTA
2626
from chia.wallet.wallet_spend_bundle import WalletSpendBundle
2727
from chia.wallet.wallet_state_manager import WalletStateManager
@@ -420,7 +420,7 @@ async def get_puzzle_hash_state() -> PuzzleHashState:
420420
(0,),
421421
)
422422
with pytest.raises(ValueError):
423-
await rpc_client.extend_derivation_index(0)
423+
await rpc_client.extend_derivation_index(ExtendDerivationIndex(uint32(0)))
424424

425425
# Reset to a normal state
426426
await wsm.puzzle_store.delete_wallet(wsm.main_wallet.id())
@@ -431,15 +431,17 @@ async def get_puzzle_hash_state() -> PuzzleHashState:
431431

432432
# Test an index already created
433433
with pytest.raises(ValueError):
434-
await rpc_client.extend_derivation_index(0)
434+
await rpc_client.extend_derivation_index(ExtendDerivationIndex(uint32(0)))
435435

436436
# Test an index too far in the future
437437
with pytest.raises(ValueError):
438-
await rpc_client.extend_derivation_index(MAX_DERIVATION_INDEX_DELTA + expected_state.highest_index + 1)
438+
await rpc_client.extend_derivation_index(
439+
ExtendDerivationIndex(uint32(MAX_DERIVATION_INDEX_DELTA + expected_state.highest_index + 1))
440+
)
439441

440442
# Test the actual functionality
441-
assert await rpc_client.extend_derivation_index(expected_state.highest_index + 5) == str(
442-
expected_state.highest_index + 5
443-
)
443+
assert (
444+
await rpc_client.extend_derivation_index(ExtendDerivationIndex(uint32(expected_state.highest_index + 5)))
445+
).index == expected_state.highest_index + 5
444446
expected_state = PuzzleHashState(expected_state.highest_index + 5, expected_state.used_up_to_index)
445447
assert await get_puzzle_hash_state() == expected_state

chia/cmds/wallet_funcs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
DIDSetWalletName,
5454
DIDTransferDID,
5555
DIDUpdateMetadata,
56+
ExtendDerivationIndex,
5657
FungibleAsset,
5758
GetNextAddress,
5859
GetNotifications,
@@ -443,8 +444,8 @@ async def update_derivation_index(
443444
) -> None:
444445
async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, _, _):
445446
print("Updating derivation index... This may take a while.")
446-
res = await wallet_client.extend_derivation_index(index)
447-
print(f"Updated derivation index: {res}")
447+
res = await wallet_client.extend_derivation_index(ExtendDerivationIndex(uint32(index)))
448+
print(f"Updated derivation index: {res.index}")
448449
print("Your balances may take a while to update.")
449450

450451

chia/wallet/wallet_request_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,18 @@ class GetCurrentDerivationIndexResponse(Streamable):
444444
index: Optional[uint32]
445445

446446

447+
@streamable
448+
@dataclass(frozen=True)
449+
class ExtendDerivationIndex(Streamable):
450+
index: uint32
451+
452+
453+
@streamable
454+
@dataclass(frozen=True)
455+
class ExtendDerivationIndexResponse(Streamable):
456+
index: Optional[uint32]
457+
458+
447459
@streamable
448460
@dataclass(frozen=True)
449461
class GetOffersCountResponse(Streamable):

chia/wallet/wallet_rpc_api.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@
166166
Empty,
167167
ExecuteSigningInstructions,
168168
ExecuteSigningInstructionsResponse,
169+
ExtendDerivationIndex,
170+
ExtendDerivationIndexResponse,
169171
GatherSigningInfo,
170172
GatherSigningInfoResponse,
171173
GenerateMnemonicResponse,
@@ -1879,30 +1881,26 @@ async def get_current_derivation_index(self, request: Empty) -> GetCurrentDeriva
18791881

18801882
return GetCurrentDerivationIndexResponse(index)
18811883

1882-
async def extend_derivation_index(self, request: dict[str, Any]) -> dict[str, Any]:
1884+
@marshal
1885+
async def extend_derivation_index(self, request: ExtendDerivationIndex) -> ExtendDerivationIndexResponse:
18831886
assert self.service.wallet_state_manager is not None
18841887

1885-
# Require a new max derivation index
1886-
if "index" not in request:
1887-
raise ValueError("Derivation index is required")
1888-
18891888
# Require that the wallet is fully synced
18901889
synced = await self.service.wallet_state_manager.synced()
18911890
if synced is False:
18921891
raise ValueError("Wallet needs to be fully synced before extending derivation index")
18931892

1894-
index = uint32(request["index"])
18951893
current: Optional[uint32] = await self.service.wallet_state_manager.puzzle_store.get_last_derivation_path()
18961894

18971895
# Additional sanity check that the wallet is synced
18981896
if current is None:
18991897
raise ValueError("No current derivation record found, unable to extend index")
19001898

19011899
# Require that the new index is greater than the current index
1902-
if index <= current:
1900+
if request.index <= current:
19031901
raise ValueError(f"New derivation index must be greater than current index: {current}")
19041902

1905-
if index - current > MAX_DERIVATION_INDEX_DELTA:
1903+
if request.index - current > MAX_DERIVATION_INDEX_DELTA:
19061904
raise ValueError(
19071905
"Too many derivations requested. "
19081906
f"Use a derivation index less than {current + MAX_DERIVATION_INDEX_DELTA + 1}"
@@ -1912,14 +1910,13 @@ async def extend_derivation_index(self, request: dict[str, Any]) -> dict[str, An
19121910
# to preserve the current last used index, so we call create_more_puzzle_hashes with
19131911
# mark_existing_as_used=False
19141912
result = await self.service.wallet_state_manager.create_more_puzzle_hashes(
1915-
from_zero=False, mark_existing_as_used=False, up_to_index=index, num_additional_phs=0
1913+
from_zero=False, mark_existing_as_used=False, up_to_index=request.index, num_additional_phs=0
19161914
)
19171915
await result.commit(self.service.wallet_state_manager)
19181916

1919-
updated: Optional[uint32] = await self.service.wallet_state_manager.puzzle_store.get_last_derivation_path()
1920-
updated_index = updated if updated is not None else None
1917+
updated_index = await self.service.wallet_state_manager.puzzle_store.get_last_derivation_path()
19211918

1922-
return {"success": True, "index": updated_index}
1919+
return ExtendDerivationIndexResponse(updated_index)
19231920

19241921
@marshal
19251922
async def get_notifications(self, request: GetNotifications) -> GetNotificationsResponse:

chia/wallet/wallet_rpc_client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@
8282
DLUpdateRootResponse,
8383
ExecuteSigningInstructions,
8484
ExecuteSigningInstructionsResponse,
85+
ExtendDerivationIndex,
86+
ExtendDerivationIndexResponse,
8587
GatherSigningInfo,
8688
GatherSigningInfoResponse,
8789
GenerateMnemonicResponse,
@@ -373,10 +375,10 @@ async def delete_unconfirmed_transactions(self, request: DeleteUnconfirmedTransa
373375
async def get_current_derivation_index(self) -> GetCurrentDerivationIndexResponse:
374376
return GetCurrentDerivationIndexResponse.from_json_dict(await self.fetch("get_current_derivation_index", {}))
375377

376-
async def extend_derivation_index(self, index: int) -> str:
377-
response = await self.fetch("extend_derivation_index", {"index": index})
378-
updated_index = response["index"]
379-
return str(updated_index)
378+
async def extend_derivation_index(self, request: ExtendDerivationIndex) -> ExtendDerivationIndexResponse:
379+
return ExtendDerivationIndexResponse.from_json_dict(
380+
await self.fetch("extend_derivation_index", request.to_json_dict())
381+
)
380382

381383
async def get_farmed_amount(self, include_pool_rewards: bool = False) -> dict[str, Any]:
382384
return await self.fetch("get_farmed_amount", {"include_pool_rewards": include_pool_rewards})

0 commit comments

Comments
 (0)