From 53cebb1504c1eebd983a0d93d55d912c209c85d9 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 12 Sep 2025 11:41:03 -0700 Subject: [PATCH 1/2] Port `get_offer` to `@marshal` --- chia/_tests/cmds/wallet/test_wallet.py | 35 +++++++++++++---------- chia/_tests/wallet/rpc/test_wallet_rpc.py | 9 +++--- chia/cmds/wallet_funcs.py | 5 ++-- chia/data_layer/data_layer.py | 5 +++- chia/wallet/wallet_request_types.py | 27 +++++++++++++++++ chia/wallet/wallet_rpc_api.py | 18 +++++++----- chia/wallet/wallet_rpc_client.py | 8 +++--- 7 files changed, 74 insertions(+), 33 deletions(-) diff --git a/chia/_tests/cmds/wallet/test_wallet.py b/chia/_tests/cmds/wallet/test_wallet.py index 981dfea8dbd6..1de6ae3163e0 100644 --- a/chia/_tests/cmds/wallet/test_wallet.py +++ b/chia/_tests/cmds/wallet/test_wallet.py @@ -60,6 +60,8 @@ GetHeightInfoResponse, GetNextAddress, GetNextAddressResponse, + GetOffer, + GetOfferResponse, GetTransaction, GetTransactions, GetTransactionsResponse, @@ -1163,22 +1165,25 @@ def test_cancel_offer(capsys: object, get_test_cli_clients: tuple[TestRpcClients # set RPC Client class CancelOfferRpcClient(TestWalletRpcClient): - async def get_offer(self, trade_id: bytes32, file_contents: bool = False) -> TradeRecord: - self.add_to_log("get_offer", (trade_id, file_contents)) + async def get_offer(self, request: GetOffer) -> GetOfferResponse: + self.add_to_log("get_offer", (request.trade_id, request.file_contents)) offer = Offer.from_bech32(test_offer_file_bech32) - return TradeRecord( - confirmed_at_index=uint32(0), - accepted_at_time=uint64(0), - created_at_time=uint64(12345678), - is_my_offer=True, - sent=uint32(0), - sent_to=[], - offer=bytes(offer), - taken_offer=None, - coins_of_interest=offer.get_involved_coins(), - trade_id=offer.name(), - status=uint32(TradeStatus.PENDING_ACCEPT.value), - valid_times=ConditionValidTimes(), + return GetOfferResponse( + test_offer_file_bech32, + TradeRecord( + confirmed_at_index=uint32(0), + accepted_at_time=uint64(0), + created_at_time=uint64(12345678), + is_my_offer=True, + sent=uint32(0), + sent_to=[], + offer=bytes(offer), + taken_offer=None, + coins_of_interest=offer.get_involved_coins(), + trade_id=offer.name(), + status=uint32(TradeStatus.PENDING_ACCEPT.value), + valid_times=ConditionValidTimes(), + ), ) async def cancel_offer( diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index 459060da0d84..54351842cf18 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -133,6 +133,7 @@ GetCoinRecordsByNames, GetNextAddress, GetNotifications, + GetOffer, GetOfferSummary, GetPrivateKey, GetSpendableCoins, @@ -1622,7 +1623,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ await env_1.rpc_client.cancel_offer(offer.name(), wallet_environments.tx_config, secure=False) - trade_record = await env_1.rpc_client.get_offer(offer.name(), file_contents=True) + trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name(), file_contents=True))).trade_record assert trade_record.offer == bytes(offer) assert TradeStatus(trade_record.status) == TradeStatus.CANCELLED @@ -1630,7 +1631,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ offer.name(), wallet_environments.tx_config, fee=uint64(1), secure=True ) - trade_record = await env_1.rpc_client.get_offer(offer.name()) + trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record assert TradeStatus(trade_record.status) == TradeStatus.PENDING_CANCEL create_res = await env_1.rpc_client.create_offer_for_ids( @@ -1717,7 +1718,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ ) async def is_trade_confirmed(client: WalletRpcClient, offer: Offer) -> bool: - trade_record = await client.get_offer(offer.name()) + trade_record = (await client.get_offer(GetOffer(offer.name()))).trade_record return TradeStatus(trade_record.status) == TradeStatus.CONFIRMED await time_out_assert(15, is_trade_confirmed, True, env_1.rpc_client, offer) @@ -1726,7 +1727,7 @@ async def is_trade_confirmed(client: WalletRpcClient, offer: Offer) -> bool: def only_ids(trades: list[TradeRecord]) -> list[bytes32]: return [t.trade_id for t in trades] - trade_record = await env_1.rpc_client.get_offer(offer.name()) + trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record all_offers = await env_1.rpc_client.get_all_offers(include_completed=True) # confirmed at index descending assert len(all_offers) == 2 assert only_ids(all_offers) == only_ids([trade_record, new_trade_record]) diff --git a/chia/cmds/wallet_funcs.py b/chia/cmds/wallet_funcs.py index 1b2ae3d91dad..a8c3f7559990 100644 --- a/chia/cmds/wallet_funcs.py +++ b/chia/cmds/wallet_funcs.py @@ -66,6 +66,7 @@ FungibleAsset, GetNextAddress, GetNotifications, + GetOffer, GetTransaction, GetTransactions, GetWalletBalance, @@ -787,7 +788,7 @@ async def get_offers( start = end end += batch_size else: - records = [await wallet_client.get_offer(offer_id, file_contents)] + records = [(await wallet_client.get_offer(GetOffer(offer_id, file_contents))).trade_record] if filepath is not None: with open(pathlib.Path(filepath), "w") as file: file.write(Offer.from_bytes(records[0].offer).to_bech32()) @@ -919,7 +920,7 @@ async def cancel_offer( condition_valid_times: ConditionValidTimes, ) -> list[TransactionRecord]: async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, fingerprint, config): - trade_record = await wallet_client.get_offer(offer_id, file_contents=True) + trade_record = (await wallet_client.get_offer(GetOffer(offer_id, file_contents=True))).trade_record await print_trade_record(trade_record, wallet_client, summaries=True) cli_confirm(f"Are you sure you wish to cancel offer with ID: {trade_record.trade_id}? (y/n): ") diff --git a/chia/data_layer/data_layer.py b/chia/data_layer/data_layer.py index 03d636da4adb..b5a018682587 100644 --- a/chia/data_layer/data_layer.py +++ b/chia/data_layer/data_layer.py @@ -81,6 +81,7 @@ DLUpdateMultiple, DLUpdateMultipleUpdates, DLUpdateRoot, + GetOffer, LauncherRootPair, LogIn, TakeOffer, @@ -1274,7 +1275,9 @@ async def cancel_offer(self, trade_id: bytes32, secure: bool, fee: uint64) -> No store_ids: list[bytes32] = [] if not secure: - trade_record = await self.wallet_rpc.get_offer(trade_id=trade_id, file_contents=True) + trade_record = ( + await self.wallet_rpc.get_offer(GetOffer(trade_id=trade_id, file_contents=True)) + ).trade_record trading_offer = TradingOffer.from_bytes(trade_record.offer) summary = await DataLayerWallet.get_offer_summary(offer=trading_offer) store_ids = [offered.launcher_id for offered in summary.offered] diff --git a/chia/wallet/wallet_request_types.py b/chia/wallet/wallet_request_types.py index 3e97b206704b..728bec281adb 100644 --- a/chia/wallet/wallet_request_types.py +++ b/chia/wallet/wallet_request_types.py @@ -1955,6 +1955,33 @@ class TakeOfferResponse(_OfferEndpointResponse): # Inheriting for de-dup sake pass +@streamable +@dataclass(frozen=True) +class GetOffer(Streamable): + trade_id: bytes32 + file_contents: bool = False + + +@streamable +@dataclass(frozen=True) +class GetOfferResponse(Streamable): + offer: Optional[str] + trade_record: TradeRecord + + def to_json_dict(self) -> dict[str, Any]: + return {**super().to_json_dict(), "trade_record": self.trade_record.to_json_dict_convenience()} + + @classmethod + def from_json_dict(cls, json_dict: dict[str, Any]) -> Self: + return cls( + offer=json_dict["offer"], + trade_record=TradeRecord.from_json_dict_convenience( + json_dict["trade_record"], + bytes(Offer.from_bech32(json_dict["offer"])).hex() if json_dict["offer"] is not None else "", + ), + ) + + @streamable @dataclass(frozen=True) class CancelOfferResponse(TransactionEndpointResponse): diff --git a/chia/wallet/wallet_rpc_api.py b/chia/wallet/wallet_rpc_api.py index 7294360e318c..e332240907d6 100644 --- a/chia/wallet/wallet_rpc_api.py +++ b/chia/wallet/wallet_rpc_api.py @@ -197,6 +197,8 @@ GetNextAddressResponse, GetNotifications, GetNotificationsResponse, + GetOffer, + GetOfferResponse, GetOffersCountResponse, GetOfferSummary, GetOfferSummaryResponse, @@ -2355,18 +2357,20 @@ async def take_offer( trade_record, ) - async def get_offer(self, request: dict[str, Any]) -> EndpointResult: + @marshal + async def get_offer(self, request: GetOffer) -> GetOfferResponse: trade_mgr = self.service.wallet_state_manager.trade_manager - trade_id = bytes32.from_hexstr(request["trade_id"]) - file_contents: bool = request.get("file_contents", False) - trade_record: Optional[TradeRecord] = await trade_mgr.get_trade_by_id(bytes32(trade_id)) + trade_record: Optional[TradeRecord] = await trade_mgr.get_trade_by_id(request.trade_id) if trade_record is None: - raise ValueError(f"No trade with trade id: {trade_id.hex()}") + raise ValueError(f"No trade with trade id: {request.trade_id.hex()}") offer_to_return: bytes = trade_record.offer if trade_record.taken_offer is None else trade_record.taken_offer - offer_value: Optional[str] = Offer.from_bytes(offer_to_return).to_bech32() if file_contents else None - return {"trade_record": trade_record.to_json_dict_convenience(), "offer": offer_value} + offer: Optional[str] = Offer.from_bytes(offer_to_return).to_bech32() if request.file_contents else None + return GetOfferResponse( + offer, + trade_record, + ) async def get_all_offers(self, request: dict[str, Any]) -> EndpointResult: trade_mgr = self.service.wallet_state_manager.trade_manager diff --git a/chia/wallet/wallet_rpc_client.py b/chia/wallet/wallet_rpc_client.py index 6825ebc64332..499c1cc24186 100644 --- a/chia/wallet/wallet_rpc_client.py +++ b/chia/wallet/wallet_rpc_client.py @@ -107,6 +107,8 @@ GetNextAddressResponse, GetNotifications, GetNotificationsResponse, + GetOffer, + GetOfferResponse, GetOffersCountResponse, GetOfferSummary, GetOfferSummaryResponse, @@ -702,10 +704,8 @@ async def take_offer( ) ) - async def get_offer(self, trade_id: bytes32, file_contents: bool = False) -> TradeRecord: - res = await self.fetch("get_offer", {"trade_id": trade_id.hex(), "file_contents": file_contents}) - offer_str = bytes(Offer.from_bech32(res["offer"])).hex() if file_contents else "" - return TradeRecord.from_json_dict_convenience(res["trade_record"], offer_str) + async def get_offer(self, request: GetOffer) -> GetOfferResponse: + return GetOfferResponse.from_json_dict(await self.fetch("get_offer", request.to_json_dict())) async def get_all_offers( self, From 9ad688090cf0beded11d8050f0c6b38d7dcfb91c Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 12 Sep 2025 12:09:06 -0700 Subject: [PATCH 2/2] Port `get_all_offers` to `@marshal` --- chia/_tests/cmds/wallet/test_wallet.py | 34 +++----- chia/_tests/wallet/rpc/test_wallet_rpc.py | 102 +++++++++++++++++----- chia/cmds/wallet_funcs.py | 25 +++--- chia/wallet/wallet_request_types.py | 36 ++++++++ chia/wallet/wallet_rpc_api.py | 42 +++++---- chia/wallet/wallet_rpc_client.py | 40 +-------- 6 files changed, 169 insertions(+), 110 deletions(-) diff --git a/chia/_tests/cmds/wallet/test_wallet.py b/chia/_tests/cmds/wallet/test_wallet.py index 1de6ae3163e0..38e7a4602809 100644 --- a/chia/_tests/cmds/wallet/test_wallet.py +++ b/chia/_tests/cmds/wallet/test_wallet.py @@ -56,6 +56,8 @@ ExtendDerivationIndex, ExtendDerivationIndexResponse, FungibleAsset, + GetAllOffers, + GetAllOffersResponse, GetCurrentDerivationIndexResponse, GetHeightInfoResponse, GetNextAddress, @@ -948,32 +950,22 @@ def test_get_offers(capsys: object, get_test_cli_clients: tuple[TestRpcClients, # set RPC Client class GetOffersRpcClient(TestWalletRpcClient): - async def get_all_offers( - self, - start: int = 0, - end: int = 50, - sort_key: Optional[str] = None, - reverse: bool = False, - file_contents: bool = False, - exclude_my_offers: bool = False, - exclude_taken_offers: bool = False, - include_completed: bool = False, - ) -> list[TradeRecord]: + async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse: self.add_to_log( "get_all_offers", ( - start, - end, - sort_key, - reverse, - file_contents, - exclude_my_offers, - exclude_taken_offers, - include_completed, + request.start, + request.end, + request.sort_key, + request.reverse, + request.file_contents, + request.exclude_my_offers, + request.exclude_taken_offers, + request.include_completed, ), ) records: list[TradeRecord] = [] - for i in reversed(range(start, end - 1)): # reversed to match the sort order + for i in reversed(range(request.start, request.end - 1)): # reversed to match the sort order trade_offer = TradeRecord( confirmed_at_index=uint32(0), accepted_at_time=None, @@ -997,7 +989,7 @@ async def get_all_offers( ), ) records.append(trade_offer) - return records + return GetAllOffersResponse([], records) inst_rpc_client = GetOffersRpcClient() test_rpc_clients.wallet_rpc_client = inst_rpc_client diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index 54351842cf18..e83238e0c502 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -130,6 +130,7 @@ DIDTransferDID, DIDUpdateMetadata, FungibleAsset, + GetAllOffers, GetCoinRecordsByNames, GetNextAddress, GetNotifications, @@ -1555,7 +1556,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ CreateOfferForIDs(offer={str(1): "-5", cat_asset_id.hex(): "1"}, validate_only=True), tx_config=wallet_environments.tx_config, ) - all_offers = await env_1.rpc_client.get_all_offers() + all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records assert len(all_offers) == 0 driver_dict = { @@ -1599,7 +1600,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ assert offer_validity_response.id == offer.name() assert offer_validity_response.valid - all_offers = await env_1.rpc_client.get_all_offers(file_contents=True) + all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers(file_contents=True))).trade_records assert len(all_offers) == 1 assert TradeStatus(all_offers[0].status) == TradeStatus.PENDING_ACCEPT assert all_offers[0].offer == bytes(offer) @@ -1638,7 +1639,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ CreateOfferForIDs(offer={str(1): "-5", str(cat_wallet_id): "1"}, fee=uint64(1)), tx_config=wallet_environments.tx_config, ) - all_offers = await env_1.rpc_client.get_all_offers() + all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records assert len(all_offers) == 2 offer_count = await env_1.rpc_client.get_offers_count() assert offer_count.total == 2 @@ -1728,25 +1729,35 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: return [t.trade_id for t in trades] trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record - all_offers = await env_1.rpc_client.get_all_offers(include_completed=True) # confirmed at index descending + all_offers = ( # confirmed at index descending + await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True)) + ).trade_records assert len(all_offers) == 2 assert only_ids(all_offers) == only_ids([trade_record, new_trade_record]) - all_offers = await env_1.rpc_client.get_all_offers( - include_completed=True, reverse=True - ) # confirmed at index ascending + all_offers = ( # confirmed at index ascending + await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, reverse=True)) + ).trade_records assert only_ids(all_offers) == only_ids([new_trade_record, trade_record]) - all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, sort_key="RELEVANCE") # most relevant + all_offers = ( # most relevant + await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE")) + ).trade_records assert only_ids(all_offers) == only_ids([new_trade_record, trade_record]) - all_offers = await env_1.rpc_client.get_all_offers( - include_completed=True, sort_key="RELEVANCE", reverse=True - ) # least relevant + all_offers = ( # least relevant + await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE", reverse=True)) + ).trade_records assert only_ids(all_offers) == only_ids([trade_record, new_trade_record]) # Test pagination - all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=1) + all_offers = ( + await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(1))) + ).trade_records assert len(all_offers) == 1 - all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=50) + all_offers = ( + await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(10))) + ).trade_records assert len(all_offers) == 0 - all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=50) + all_offers = ( + await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(50))) + ).trade_records assert len(all_offers) == 2 await env_1.rpc_client.create_offer_for_ids( @@ -1754,11 +1765,25 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: tx_config=wallet_environments.tx_config, ) assert ( - len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2 + len( + [ + o + for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records + if o.status == TradeStatus.PENDING_ACCEPT.value + ] + ) + == 2 ) await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, batch_size=1) assert ( - len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0 + len( + [ + o + for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records + if o.status == TradeStatus.PENDING_ACCEPT.value + ] + ) + == 0 ) await wallet_environments.process_pending_states( [ @@ -1795,11 +1820,25 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: tx_config=wallet_environments.tx_config, ) assert ( - len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2 + len( + [ + o + for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records + if o.status == TradeStatus.PENDING_ACCEPT.value + ] + ) + == 2 ) await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, cancel_all=True) assert ( - len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0 + len( + [ + o + for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records + if o.status == TradeStatus.PENDING_ACCEPT.value + ] + ) + == 0 ) await wallet_environments.process_pending_states( @@ -1843,15 +1882,36 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: tx_config=wallet_environments.tx_config, ) assert ( - len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1 + len( + [ + o + for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records + if o.status == TradeStatus.PENDING_ACCEPT.value + ] + ) + == 1 ) await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=bytes32.zeros) assert ( - len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1 + len( + [ + o + for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records + if o.status == TradeStatus.PENDING_ACCEPT.value + ] + ) + == 1 ) await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=cat_asset_id) assert ( - len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0 + len( + [ + o + for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records + if o.status == TradeStatus.PENDING_ACCEPT.value + ] + ) + == 0 ) with pytest.raises(ValueError, match="not currently supported"): diff --git a/chia/cmds/wallet_funcs.py b/chia/cmds/wallet_funcs.py index a8c3f7559990..bc8f0ffe2649 100644 --- a/chia/cmds/wallet_funcs.py +++ b/chia/cmds/wallet_funcs.py @@ -64,6 +64,7 @@ DIDUpdateMetadata, ExtendDerivationIndex, FungibleAsset, + GetAllOffers, GetNextAddress, GetNotifications, GetOffer, @@ -769,16 +770,20 @@ async def get_offers( # Traverse offers page by page while True: - new_records: list[TradeRecord] = await wallet_client.get_all_offers( - start, - end, - sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT", - reverse=reverse, - file_contents=file_contents, - exclude_my_offers=exclude_my_offers, - exclude_taken_offers=exclude_taken_offers, - include_completed=include_completed, - ) + new_records: list[TradeRecord] = ( + await wallet_client.get_all_offers( + GetAllOffers( + start=uint16(start), + end=uint16(end), + sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT", + reverse=reverse, + file_contents=file_contents, + exclude_my_offers=exclude_my_offers, + exclude_taken_offers=exclude_taken_offers, + include_completed=include_completed, + ) + ) + ).trade_records records.extend(new_records) # If fewer records were returned than requested, we're done diff --git a/chia/wallet/wallet_request_types.py b/chia/wallet/wallet_request_types.py index 728bec281adb..b7c3f3c9ece8 100644 --- a/chia/wallet/wallet_request_types.py +++ b/chia/wallet/wallet_request_types.py @@ -1982,6 +1982,42 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> Self: ) +@streamable +@dataclass(frozen=True) +class GetAllOffers(Streamable): + start: uint16 = uint16(0) + end: uint16 = uint16(10) + exclude_my_offers: bool = False + exclude_taken_offers: bool = False + include_completed: bool = False + sort_key: Optional[str] = None + reverse: bool = False + file_contents: bool = False + + +@streamable +@dataclass(frozen=True) +class GetAllOffersResponse(Streamable): + offers: Optional[list[str]] + trade_records: list[TradeRecord] + + def to_json_dict(self) -> dict[str, Any]: + return {**super().to_json_dict(), "trade_records": [tr.to_json_dict_convenience() for tr in self.trade_records]} + + @classmethod + def from_json_dict(cls, json_dict: dict[str, Any]) -> Self: + return cls( + offers=json_dict["offers"], + trade_records=[ + TradeRecord.from_json_dict_convenience( + json_tr, + bytes(Offer.from_bech32(json_dict["offers"][i])).hex() if json_dict["offers"] is not None else "", + ) + for i, json_tr in enumerate(json_dict["trade_records"]) + ], + ) + + @streamable @dataclass(frozen=True) class CancelOfferResponse(TransactionEndpointResponse): diff --git a/chia/wallet/wallet_rpc_api.py b/chia/wallet/wallet_rpc_api.py index e332240907d6..727ab9c593b0 100644 --- a/chia/wallet/wallet_rpc_api.py +++ b/chia/wallet/wallet_rpc_api.py @@ -187,6 +187,8 @@ GatherSigningInfo, GatherSigningInfoResponse, GenerateMnemonicResponse, + GetAllOffers, + GetAllOffersResponse, GetCATListResponse, GetCoinRecordsByNames, GetCoinRecordsByNamesResponse, @@ -2372,36 +2374,32 @@ async def get_offer(self, request: GetOffer) -> GetOfferResponse: trade_record, ) - async def get_all_offers(self, request: dict[str, Any]) -> EndpointResult: + @marshal + async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse: trade_mgr = self.service.wallet_state_manager.trade_manager - start: int = request.get("start", 0) - end: int = request.get("end", 10) - exclude_my_offers: bool = request.get("exclude_my_offers", False) - exclude_taken_offers: bool = request.get("exclude_taken_offers", False) - include_completed: bool = request.get("include_completed", False) - sort_key: Optional[str] = request.get("sort_key", None) - reverse: bool = request.get("reverse", False) - file_contents: bool = request.get("file_contents", False) - all_trades = await trade_mgr.trade_store.get_trades_between( - start, - end, - sort_key=sort_key, - reverse=reverse, - exclude_my_offers=exclude_my_offers, - exclude_taken_offers=exclude_taken_offers, - include_completed=include_completed, + request.start, + request.end, + sort_key=request.sort_key, + reverse=request.reverse, + exclude_my_offers=request.exclude_my_offers, + exclude_taken_offers=request.exclude_taken_offers, + include_completed=request.include_completed, ) result = [] - offer_values: Optional[list[str]] = [] if file_contents else None + offer_values: Optional[list[str]] = [] if request.file_contents else None for trade in all_trades: - result.append(trade.to_json_dict_convenience()) - if file_contents and offer_values is not None: + result.append(trade) + if request.file_contents: offer_to_return: bytes = trade.offer if trade.taken_offer is None else trade.taken_offer - offer_values.append(Offer.from_bytes(offer_to_return).to_bech32()) + # semantics guarantee this to be not None + offer_values.append(Offer.from_bytes(offer_to_return).to_bech32()) # type: ignore[union-attr] - return {"trade_records": result, "offers": offer_values} + return GetAllOffersResponse( + trade_records=result, + offers=offer_values, + ) @marshal async def get_offers_count(self, request: Empty) -> GetOffersCountResponse: diff --git a/chia/wallet/wallet_rpc_client.py b/chia/wallet/wallet_rpc_client.py index 499c1cc24186..6fea02cd2079 100644 --- a/chia/wallet/wallet_rpc_client.py +++ b/chia/wallet/wallet_rpc_client.py @@ -10,8 +10,6 @@ from chia.types.blockchain_format.coin import Coin from chia.wallet.conditions import Condition, ConditionValidTimes, conditions_to_json_dicts from chia.wallet.puzzles.clawback.metadata import AutoClaimSettings -from chia.wallet.trade_record import TradeRecord -from chia.wallet.trading.offer import Offer from chia.wallet.transaction_record import TransactionRecord from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable from chia.wallet.util.tx_config import TXConfig @@ -97,6 +95,8 @@ GatherSigningInfo, GatherSigningInfoResponse, GenerateMnemonicResponse, + GetAllOffers, + GetAllOffersResponse, GetCATListResponse, GetCoinRecordsByNames, GetCoinRecordsByNamesResponse, @@ -707,40 +707,8 @@ async def take_offer( async def get_offer(self, request: GetOffer) -> GetOfferResponse: return GetOfferResponse.from_json_dict(await self.fetch("get_offer", request.to_json_dict())) - async def get_all_offers( - self, - start: int = 0, - end: int = 50, - sort_key: Optional[str] = None, - reverse: bool = False, - file_contents: bool = False, - exclude_my_offers: bool = False, - exclude_taken_offers: bool = False, - include_completed: bool = False, - ) -> list[TradeRecord]: - res = await self.fetch( - "get_all_offers", - { - "start": start, - "end": end, - "sort_key": sort_key, - "reverse": reverse, - "file_contents": file_contents, - "exclude_my_offers": exclude_my_offers, - "exclude_taken_offers": exclude_taken_offers, - "include_completed": include_completed, - }, - ) - - records = [] - if file_contents: - optional_offers = [bytes(Offer.from_bech32(o)).hex() for o in res["offers"]] - else: - optional_offers = [""] * len(res["trade_records"]) - for record, offer in zip(res["trade_records"], optional_offers): - records.append(TradeRecord.from_json_dict_convenience(record, offer)) - - return records + async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse: + return GetAllOffersResponse.from_json_dict(await self.fetch("get_all_offers", request.to_json_dict())) async def get_offers_count(self) -> GetOffersCountResponse: return GetOffersCountResponse.from_json_dict(await self.fetch("get_offers_count", {}))