Skip to content

Commit 71e1c7c

Browse files
authored
[CHIA-3599] Port get_transaction_memo to @marshal (#19940)
* Port `get_transaction_memo` to `@marshal` * comment requested by @altendky
1 parent c700d14 commit 71e1c7c

File tree

4 files changed

+32
-45
lines changed

4 files changed

+32
-45
lines changed

chia/_tests/wallet/cat_wallet/test_cat_wallet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ async def test_cat_spend(wallet_environments: WalletTestFramework, wallet_type:
377377
if tx_record.spend_bundle is not None:
378378
tx_id = tx_record.name
379379
assert tx_id is not None
380-
memos = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
381-
assert len(memos.coins_with_memos) == 2
382-
assert cat_2_hash in {coin_w_memos.memos[0] for coin_w_memos in memos.coins_with_memos}
380+
memo_response = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
381+
assert len(memo_response.memo_dict) == 2
382+
assert cat_2_hash in {memos[0] for memos in memo_response.memo_dict.values()}
383383

384384
await wallet_environments.process_pending_states(
385385
[
@@ -454,9 +454,9 @@ async def test_cat_spend(wallet_environments: WalletTestFramework, wallet_type:
454454
assert len(coins) == 1
455455
coin = coins.pop()
456456
tx_id = coin.name()
457-
memos = await env_2.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
458-
assert len(memos.coins_with_memos) == 2
459-
assert cat_2_hash in {coin_w_memos.memos[0] for coin_w_memos in memos.coins_with_memos}
457+
memo_response = await env_2.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
458+
assert len(memo_response.memo_dict) == 2
459+
assert cat_2_hash in {memos[0] for memos in memo_response.memo_dict.values()}
460460
async with cat_wallet.wallet_state_manager.new_action_scope(
461461
wallet_environments.tx_config, push=True
462462
) as action_scope:

chia/_tests/wallet/test_wallet.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,9 +1542,9 @@ async def test_wallet_make_transaction_with_memo(self, wallet_environments: Wall
15421542
fees = estimate_fees(tx.spend_bundle)
15431543
assert fees == tx_fee
15441544

1545-
memos = await env_0.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx.name))
1546-
assert len(memos.coins_with_memos) == 1
1547-
assert memos.coins_with_memos[0].memos[0] == ph_2
1545+
memo_response = await env_0.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx.name))
1546+
assert len(memo_response.memo_dict) == 1
1547+
assert next(iter(memo_response.memo_dict.values()))[0] == ph_2
15481548

15491549
await wallet_environments.process_pending_states(
15501550
[
@@ -1589,13 +1589,9 @@ async def test_wallet_make_transaction_with_memo(self, wallet_environments: Wall
15891589
if coin.amount == tx_amount:
15901590
tx_id = coin.name()
15911591
assert tx_id is not None
1592-
memos = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
1593-
assert len(memos.coins_with_memos) == 1
1594-
assert memos.coins_with_memos[0].memos[0] == ph_2
1595-
# test json serialization
1596-
assert memos.to_json_dict() == {
1597-
tx_id.hex(): {memos.coins_with_memos[0].coin_id.hex(): [memos.coins_with_memos[0].memos[0].hex()]}
1598-
}
1592+
memo_response = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
1593+
assert len(memo_response.memo_dict) == 1
1594+
assert next(iter(memo_response.memo_dict.values()))[0] == ph_2
15991595

16001596
@pytest.mark.parametrize(
16011597
"wallet_environments",

chia/wallet/wallet_request_types.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -376,36 +376,29 @@ class GetTransactionMemo(Streamable):
376376
transaction_id: bytes32
377377

378378

379-
# utility type for GetTransactionMemoResponse
380-
@streamable
381-
@dataclass(frozen=True)
382-
class CoinIDWithMemos(Streamable):
383-
coin_id: bytes32
384-
memos: list[bytes]
385-
386-
387379
@streamable
388380
@dataclass(frozen=True)
389381
class GetTransactionMemoResponse(Streamable):
390-
transaction_id: bytes32
391-
coins_with_memos: list[CoinIDWithMemos]
382+
transaction_memos: dict[bytes32, dict[bytes32, list[bytes]]]
383+
384+
@property
385+
def memo_dict(self) -> dict[bytes32, list[bytes]]:
386+
return next(iter(self.transaction_memos.values()))
392387

393388
# TODO: deprecate the kinda silly format of this RPC and delete these functions
394389
def to_json_dict(self) -> dict[str, Any]:
395-
return {
396-
self.transaction_id.hex(): {
397-
cwm.coin_id.hex(): [memo.hex() for memo in cwm.memos] for cwm in self.coins_with_memos
398-
}
399-
}
390+
# This is semantically guaranteed but mypy can't know that
391+
return super().to_json_dict()["transaction_memos"] # type: ignore[no-any-return]
400392

401393
@classmethod
402394
def from_json_dict(cls, json_dict: dict[str, Any]) -> GetTransactionMemoResponse:
403-
return cls(
404-
bytes32.from_hexstr(next(iter(json_dict.keys()))),
405-
[
406-
CoinIDWithMemos(bytes32.from_hexstr(coin_id), [bytes32.from_hexstr(memo) for memo in memos])
407-
for coin_id, memos in next(iter(json_dict.values())).items()
408-
],
395+
return super().from_json_dict(
396+
# We have to filter out the "success" key here
397+
# because it doesn't match our `transaction_memos` hint
398+
#
399+
# We do this by only allowing the keys with "0x"
400+
# which we can assume exist because we serialize all responses
401+
{"transaction_memos": {key: value for key, value in json_dict.items() if key.startswith("0x")}}
409402
)
410403

411404

chia/wallet/wallet_rpc_api.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@
180180
GetTimestampForHeight,
181181
GetTimestampForHeightResponse,
182182
GetTransaction,
183+
GetTransactionMemo,
184+
GetTransactionMemoResponse,
183185
GetTransactionResponse,
184186
GetTransactions,
185187
GetTransactionsResponse,
@@ -1317,8 +1319,9 @@ async def get_transaction(self, request: GetTransaction) -> GetTransactionRespon
13171319
tr.name,
13181320
)
13191321

1320-
async def get_transaction_memo(self, request: dict[str, Any]) -> EndpointResult:
1321-
transaction_id: bytes32 = bytes32.from_hexstr(request["transaction_id"])
1322+
@marshal
1323+
async def get_transaction_memo(self, request: GetTransactionMemo) -> GetTransactionMemoResponse:
1324+
transaction_id: bytes32 = request.transaction_id
13221325
tr: Optional[TransactionRecord] = await self.service.wallet_state_manager.get_transaction(transaction_id)
13231326
if tr is None:
13241327
raise ValueError(f"Transaction 0x{transaction_id.hex()} not found")
@@ -1336,12 +1339,7 @@ async def get_transaction_memo(self, request: dict[str, Any]) -> EndpointResult:
13361339
else:
13371340
raise ValueError(f"Transaction 0x{transaction_id.hex()} doesn't have any coin spend.")
13381341
assert tr.spend_bundle is not None
1339-
memos: dict[bytes32, list[bytes]] = compute_memos(tr.spend_bundle)
1340-
response = {}
1341-
# Convert to hex string
1342-
for coin_id, memo_list in memos.items():
1343-
response[coin_id.hex()] = [memo.hex() for memo in memo_list]
1344-
return {transaction_id.hex(): response}
1342+
return GetTransactionMemoResponse({transaction_id: compute_memos(tr.spend_bundle)})
13451343

13461344
@tx_endpoint(push=False)
13471345
@marshal

0 commit comments

Comments
 (0)