diff --git a/chia/_tests/wallet/cat_wallet/test_cat_wallet.py b/chia/_tests/wallet/cat_wallet/test_cat_wallet.py index 93d04a608149..0440114bce9d 100644 --- a/chia/_tests/wallet/cat_wallet/test_cat_wallet.py +++ b/chia/_tests/wallet/cat_wallet/test_cat_wallet.py @@ -377,9 +377,9 @@ async def test_cat_spend(wallet_environments: WalletTestFramework, wallet_type: if tx_record.spend_bundle is not None: tx_id = tx_record.name assert tx_id is not None - memos = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id)) - assert len(memos.coins_with_memos) == 2 - assert cat_2_hash in {coin_w_memos.memos[0] for coin_w_memos in memos.coins_with_memos} + memo_response = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id)) + assert len(memo_response.memo_dict) == 2 + assert cat_2_hash in {memos[0] for memos in memo_response.memo_dict.values()} await wallet_environments.process_pending_states( [ @@ -454,9 +454,9 @@ async def test_cat_spend(wallet_environments: WalletTestFramework, wallet_type: assert len(coins) == 1 coin = coins.pop() tx_id = coin.name() - memos = await env_2.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id)) - assert len(memos.coins_with_memos) == 2 - assert cat_2_hash in {coin_w_memos.memos[0] for coin_w_memos in memos.coins_with_memos} + memo_response = await env_2.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id)) + assert len(memo_response.memo_dict) == 2 + assert cat_2_hash in {memos[0] for memos in memo_response.memo_dict.values()} async with cat_wallet.wallet_state_manager.new_action_scope( wallet_environments.tx_config, push=True ) as action_scope: diff --git a/chia/_tests/wallet/test_wallet.py b/chia/_tests/wallet/test_wallet.py index b3f6161d480f..a24258f22559 100644 --- a/chia/_tests/wallet/test_wallet.py +++ b/chia/_tests/wallet/test_wallet.py @@ -1542,9 +1542,9 @@ async def test_wallet_make_transaction_with_memo(self, wallet_environments: Wall fees = estimate_fees(tx.spend_bundle) assert fees == tx_fee - memos = await env_0.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx.name)) - assert len(memos.coins_with_memos) == 1 - assert memos.coins_with_memos[0].memos[0] == ph_2 + memo_response = await env_0.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx.name)) + assert len(memo_response.memo_dict) == 1 + assert next(iter(memo_response.memo_dict.values()))[0] == ph_2 await wallet_environments.process_pending_states( [ @@ -1589,13 +1589,9 @@ async def test_wallet_make_transaction_with_memo(self, wallet_environments: Wall if coin.amount == tx_amount: tx_id = coin.name() assert tx_id is not None - memos = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id)) - assert len(memos.coins_with_memos) == 1 - assert memos.coins_with_memos[0].memos[0] == ph_2 - # test json serialization - assert memos.to_json_dict() == { - tx_id.hex(): {memos.coins_with_memos[0].coin_id.hex(): [memos.coins_with_memos[0].memos[0].hex()]} - } + memo_response = await env_1.rpc_client.get_transaction_memo(GetTransactionMemo(transaction_id=tx_id)) + assert len(memo_response.memo_dict) == 1 + assert next(iter(memo_response.memo_dict.values()))[0] == ph_2 @pytest.mark.parametrize( "wallet_environments", diff --git a/chia/wallet/wallet_request_types.py b/chia/wallet/wallet_request_types.py index 2880cda9eb62..10c56d1b78b8 100644 --- a/chia/wallet/wallet_request_types.py +++ b/chia/wallet/wallet_request_types.py @@ -376,36 +376,29 @@ class GetTransactionMemo(Streamable): transaction_id: bytes32 -# utility type for GetTransactionMemoResponse -@streamable -@dataclass(frozen=True) -class CoinIDWithMemos(Streamable): - coin_id: bytes32 - memos: list[bytes] - - @streamable @dataclass(frozen=True) class GetTransactionMemoResponse(Streamable): - transaction_id: bytes32 - coins_with_memos: list[CoinIDWithMemos] + transaction_memos: dict[bytes32, dict[bytes32, list[bytes]]] + + @property + def memo_dict(self) -> dict[bytes32, list[bytes]]: + return next(iter(self.transaction_memos.values())) # TODO: deprecate the kinda silly format of this RPC and delete these functions def to_json_dict(self) -> dict[str, Any]: - return { - self.transaction_id.hex(): { - cwm.coin_id.hex(): [memo.hex() for memo in cwm.memos] for cwm in self.coins_with_memos - } - } + # This is semantically guaranteed but mypy can't know that + return super().to_json_dict()["transaction_memos"] # type: ignore[no-any-return] @classmethod def from_json_dict(cls, json_dict: dict[str, Any]) -> GetTransactionMemoResponse: - return cls( - bytes32.from_hexstr(next(iter(json_dict.keys()))), - [ - CoinIDWithMemos(bytes32.from_hexstr(coin_id), [bytes32.from_hexstr(memo) for memo in memos]) - for coin_id, memos in next(iter(json_dict.values())).items() - ], + return super().from_json_dict( + # We have to filter out the "success" key here + # because it doesn't match our `transaction_memos` hint + # + # We do this by only allowing the keys with "0x" + # which we can assume exist because we serialize all responses + {"transaction_memos": {key: value for key, value in json_dict.items() if key.startswith("0x")}} ) diff --git a/chia/wallet/wallet_rpc_api.py b/chia/wallet/wallet_rpc_api.py index 6324025ad592..a7955f7a1f4a 100644 --- a/chia/wallet/wallet_rpc_api.py +++ b/chia/wallet/wallet_rpc_api.py @@ -180,6 +180,8 @@ GetTimestampForHeight, GetTimestampForHeightResponse, GetTransaction, + GetTransactionMemo, + GetTransactionMemoResponse, GetTransactionResponse, GetTransactions, GetTransactionsResponse, @@ -1317,8 +1319,9 @@ async def get_transaction(self, request: GetTransaction) -> GetTransactionRespon tr.name, ) - async def get_transaction_memo(self, request: dict[str, Any]) -> EndpointResult: - transaction_id: bytes32 = bytes32.from_hexstr(request["transaction_id"]) + @marshal + async def get_transaction_memo(self, request: GetTransactionMemo) -> GetTransactionMemoResponse: + transaction_id: bytes32 = request.transaction_id tr: Optional[TransactionRecord] = await self.service.wallet_state_manager.get_transaction(transaction_id) if tr is None: raise ValueError(f"Transaction 0x{transaction_id.hex()} not found") @@ -1336,12 +1339,7 @@ async def get_transaction_memo(self, request: dict[str, Any]) -> EndpointResult: else: raise ValueError(f"Transaction 0x{transaction_id.hex()} doesn't have any coin spend.") assert tr.spend_bundle is not None - memos: dict[bytes32, list[bytes]] = compute_memos(tr.spend_bundle) - response = {} - # Convert to hex string - for coin_id, memo_list in memos.items(): - response[coin_id.hex()] = [memo.hex() for memo in memo_list] - return {transaction_id.hex(): response} + return GetTransactionMemoResponse({transaction_id: compute_memos(tr.spend_bundle)}) @tx_endpoint(push=False) @marshal