Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions chia/_tests/wallet/cat_wallet/test_cat_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,9 @@ async def test_cat_spend(wallet_environments: WalletTestFramework, wallet_type:
if tx_record.spend_bundle is not None:
tx_id = tx_record.name
assert tx_id is not None
memos = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
assert len(memos.coins_with_memos) == 2
assert cat_2_hash in {coin_w_memos.memos[0] for coin_w_memos in memos.coins_with_memos}
memo_response = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
assert len(memo_response.memo_dict) == 2
assert cat_2_hash in {memos[0] for memos in memo_response.memo_dict.values()}

await wallet_environments.process_pending_states(
[
Expand Down Expand Up @@ -454,9 +454,9 @@ async def test_cat_spend(wallet_environments: WalletTestFramework, wallet_type:
assert len(coins) == 1
coin = coins.pop()
tx_id = coin.name()
memos = await env_2.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
assert len(memos.coins_with_memos) == 2
assert cat_2_hash in {coin_w_memos.memos[0] for coin_w_memos in memos.coins_with_memos}
memo_response = await env_2.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
assert len(memo_response.memo_dict) == 2
assert cat_2_hash in {memos[0] for memos in memo_response.memo_dict.values()}
async with cat_wallet.wallet_state_manager.new_action_scope(
wallet_environments.tx_config, push=True
) as action_scope:
Expand Down
16 changes: 6 additions & 10 deletions chia/_tests/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,9 +1542,9 @@ async def test_wallet_make_transaction_with_memo(self, wallet_environments: Wall
fees = estimate_fees(tx.spend_bundle)
assert fees == tx_fee

memos = await env_0.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx.name))
assert len(memos.coins_with_memos) == 1
assert memos.coins_with_memos[0].memos[0] == ph_2
memo_response = await env_0.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx.name))
assert len(memo_response.memo_dict) == 1
assert next(iter(memo_response.memo_dict.values()))[0] == ph_2

await wallet_environments.process_pending_states(
[
Expand Down Expand Up @@ -1589,13 +1589,9 @@ async def test_wallet_make_transaction_with_memo(self, wallet_environments: Wall
if coin.amount == tx_amount:
tx_id = coin.name()
assert tx_id is not None
memos = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
assert len(memos.coins_with_memos) == 1
assert memos.coins_with_memos[0].memos[0] == ph_2
# test json serialization
assert memos.to_json_dict() == {
tx_id.hex(): {memos.coins_with_memos[0].coin_id.hex(): [memos.coins_with_memos[0].memos[0].hex()]}
}
memo_response = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id))
assert len(memo_response.memo_dict) == 1
assert next(iter(memo_response.memo_dict.values()))[0] == ph_2

@pytest.mark.parametrize(
"wallet_environments",
Expand Down
35 changes: 14 additions & 21 deletions chia/wallet/wallet_request_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,36 +376,29 @@ class GetTransactionMemo(Streamable):
transaction_id: bytes32


# utility type for GetTransactionMemoResponse
@streamable
@dataclass(frozen=True)
class CoinIDWithMemos(Streamable):
coin_id: bytes32
memos: list[bytes]


@streamable
@dataclass(frozen=True)
class GetTransactionMemoResponse(Streamable):
transaction_id: bytes32
coins_with_memos: list[CoinIDWithMemos]
transaction_memos: dict[bytes32, dict[bytes32, list[bytes]]]

@property
def memo_dict(self) -> dict[bytes32, list[bytes]]:
return next(iter(self.transaction_memos.values()))

# TODO: deprecate the kinda silly format of this RPC and delete these functions
def to_json_dict(self) -> dict[str, Any]:
return {
self.transaction_id.hex(): {
cwm.coin_id.hex(): [memo.hex() for memo in cwm.memos] for cwm in self.coins_with_memos
}
}
# This is semantically guaranteed but mypy can't know that
return super().to_json_dict()["transaction_memos"] # type: ignore[no-any-return]

@classmethod
def from_json_dict(cls, json_dict: dict[str, Any]) -> GetTransactionMemoResponse:
return cls(
bytes32.from_hexstr(next(iter(json_dict.keys()))),
[
CoinIDWithMemos(bytes32.from_hexstr(coin_id), [bytes32.from_hexstr(memo) for memo in memos])
for coin_id, memos in next(iter(json_dict.values())).items()
],
return super().from_json_dict(
# We have to filter out the "success" key here
# because it doesn't match our `transaction_memos` hint
#
# We do this by only allowing the keys with "0x"
# which we can assume exist because we serialize all responses
{"transaction_memos": {key: value for key, value in json_dict.items() if key.startswith("0x")}}
)


Expand Down
14 changes: 6 additions & 8 deletions chia/wallet/wallet_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@
GetTimestampForHeight,
GetTimestampForHeightResponse,
GetTransaction,
GetTransactionMemo,
GetTransactionMemoResponse,
GetTransactionResponse,
GetTransactions,
GetTransactionsResponse,
Expand Down Expand Up @@ -1317,8 +1319,9 @@ async def get_transaction(self, request: GetTransaction) -> GetTransactionRespon
tr.name,
)

async def get_transaction_memo(self, request: dict[str, Any]) -> EndpointResult:
transaction_id: bytes32 = bytes32.from_hexstr(request["transaction_id"])
@marshal
async def get_transaction_memo(self, request: GetTransactionMemo) -> GetTransactionMemoResponse:
transaction_id: bytes32 = request.transaction_id
tr: Optional[TransactionRecord] = await self.service.wallet_state_manager.get_transaction(transaction_id)
if tr is None:
raise ValueError(f"Transaction 0x{transaction_id.hex()} not found")
Expand All @@ -1336,12 +1339,7 @@ async def get_transaction_memo(self, request: dict[str, Any]) -> EndpointResult:
else:
raise ValueError(f"Transaction 0x{transaction_id.hex()} doesn't have any coin spend.")
assert tr.spend_bundle is not None
memos: dict[bytes32, list[bytes]] = compute_memos(tr.spend_bundle)
response = {}
# Convert to hex string
for coin_id, memo_list in memos.items():
response[coin_id.hex()] = [memo.hex() for memo in memo_list]
return {transaction_id.hex(): response}
return GetTransactionMemoResponse({transaction_id: compute_memos(tr.spend_bundle)})

@tx_endpoint(push=False)
@marshal
Expand Down
Loading