diff --git a/chia/_tests/cmds/wallet/test_wallet.py b/chia/_tests/cmds/wallet/test_wallet.py index 38e7a4602809..37b410ef9650 100644 --- a/chia/_tests/cmds/wallet/test_wallet.py +++ b/chia/_tests/cmds/wallet/test_wallet.py @@ -42,6 +42,7 @@ from chia.wallet.wallet_coin_store import GetCoinRecords from chia.wallet.wallet_request_types import ( BalanceResponse, + CancelOffer, CancelOfferResponse, CATAssetIDToName, CATAssetIDToNameResponse, @@ -1180,14 +1181,14 @@ async def get_offer(self, request: GetOffer) -> GetOfferResponse: async def cancel_offer( self, - trade_id: bytes32, + request: CancelOffer, tx_config: TXConfig, - fee: uint64 = uint64(0), - secure: bool = True, - push: bool = True, + extra_conditions: tuple[Condition, ...] = tuple(), timelock_info: ConditionValidTimes = ConditionValidTimes(), ) -> CancelOfferResponse: - self.add_to_log("cancel_offer", (trade_id, tx_config, fee, secure, push, timelock_info)) + self.add_to_log( + "cancel_offer", (request.trade_id, tx_config, request.fee, request.secure, request.push, timelock_info) + ) return CancelOfferResponse([STD_UTX], [STD_TX]) inst_rpc_client = CancelOfferRpcClient() diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index e83238e0c502..febaaa715f2d 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -106,6 +106,8 @@ from chia.wallet.wallet_protocol import WalletProtocol from chia.wallet.wallet_request_types import ( AddKey, + CancelOffer, + CancelOffers, CATAssetIDToName, CATGetAssetID, CATGetName, @@ -1622,14 +1624,22 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ ).trade_record assert TradeStatus(trade_record.status) == TradeStatus.PENDING_CONFIRM - await env_1.rpc_client.cancel_offer(offer.name(), wallet_environments.tx_config, secure=False) + await env_1.rpc_client.cancel_offer( + CancelOffer( + trade_id=offer.name(), + secure=False, + push=True, + ), + tx_config=wallet_environments.tx_config, + ) trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name(), file_contents=True))).trade_record assert trade_record.offer == bytes(offer) assert TradeStatus(trade_record.status) == TradeStatus.CANCELLED failed_cancel_res = await env_1.rpc_client.cancel_offer( - offer.name(), wallet_environments.tx_config, fee=uint64(1), secure=True + CancelOffer(trade_id=offer.name(), fee=uint64(1), secure=True, push=True), + tx_config=wallet_environments.tx_config, ) trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record @@ -1774,7 +1784,9 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: ) == 2 ) - await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, batch_size=1) + await env_1.rpc_client.cancel_offers( + CancelOffers(secure=True, batch_size=uint16(1), push=True), tx_config=wallet_environments.tx_config + ) assert ( len( [ @@ -1829,7 +1841,9 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: ) == 2 ) - await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, cancel_all=True) + await env_1.rpc_client.cancel_offers( + CancelOffers(secure=True, cancel_all=True, push=True), tx_config=wallet_environments.tx_config + ) assert ( len( [ @@ -1891,7 +1905,9 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: ) == 1 ) - await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=bytes32.zeros) + await env_1.rpc_client.cancel_offers( + CancelOffers(secure=True, asset_id=bytes32.zeros.hex(), push=True), tx_config=wallet_environments.tx_config + ) assert ( len( [ @@ -1902,7 +1918,9 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]: ) == 1 ) - await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=cat_asset_id) + await env_1.rpc_client.cancel_offers( + CancelOffers(secure=True, asset_id=cat_asset_id.hex(), push=True), tx_config=wallet_environments.tx_config + ) assert ( len( [ diff --git a/chia/cmds/wallet_funcs.py b/chia/cmds/wallet_funcs.py index bc8f0ffe2649..437dfb420336 100644 --- a/chia/cmds/wallet_funcs.py +++ b/chia/cmds/wallet_funcs.py @@ -45,6 +45,7 @@ from chia.wallet.vc_wallet.vc_store import VCProofs from chia.wallet.wallet_coin_store import GetCoinRecords from chia.wallet.wallet_request_types import ( + CancelOffer, CATAssetIDToName, CATAssetIDToNameResponse, CATGetName, @@ -930,11 +931,8 @@ async def cancel_offer( cli_confirm(f"Are you sure you wish to cancel offer with ID: {trade_record.trade_id}? (y/n): ") res = await wallet_client.cancel_offer( - offer_id, - CMDTXConfigLoader().to_tx_config(units["chia"], config, fingerprint), - secure=secure, - fee=fee, - push=push, + CancelOffer(trade_id=offer_id, secure=secure, fee=fee, push=push), + tx_config=CMDTXConfigLoader().to_tx_config(units["chia"], config, fingerprint), timelock_info=condition_valid_times, ) if push or not secure: diff --git a/chia/data_layer/data_layer.py b/chia/data_layer/data_layer.py index b5a018682587..b8b0826fdacf 100644 --- a/chia/data_layer/data_layer.py +++ b/chia/data_layer/data_layer.py @@ -69,6 +69,7 @@ from chia.wallet.transaction_record import TransactionRecord from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG from chia.wallet.wallet_request_types import ( + CancelOffer, CreateNewDL, CreateOfferForIDs, DLDeleteMirror, @@ -1283,9 +1284,7 @@ async def cancel_offer(self, trade_id: bytes32, secure: bool, fee: uint64) -> No store_ids = [offered.launcher_id for offered in summary.offered] await self.wallet_rpc.cancel_offer( - trade_id=trade_id, - secure=secure, - fee=fee, + CancelOffer(trade_id=trade_id, secure=secure, fee=fee, push=True), # TODO: probably shouldn't be default but due to peculiarities in the RPC, we're using a stop gap. # This is not a change in behavior, the default was already implicit. tx_config=DEFAULT_TX_CONFIG, diff --git a/chia/wallet/wallet_request_types.py b/chia/wallet/wallet_request_types.py index b7c3f3c9ece8..06dc4f74954e 100644 --- a/chia/wallet/wallet_request_types.py +++ b/chia/wallet/wallet_request_types.py @@ -2018,12 +2018,29 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> Self: ) +@streamable +@dataclass(frozen=True) +class CancelOffer(TransactionEndpointRequest): + trade_id: bytes32 = field(default_factory=default_raise) + secure: bool = field(default_factory=default_raise) + + @streamable @dataclass(frozen=True) class CancelOfferResponse(TransactionEndpointResponse): pass +@streamable +@dataclass(frozen=True) +class CancelOffers(TransactionEndpointRequest): + secure: bool = field(default_factory=default_raise) + batch_fee: uint64 = uint64(0) + batch_size: uint16 = uint16(5) + cancel_all: bool = False + asset_id: str = "xch" + + @streamable @dataclass(frozen=True) class CancelOffersResponse(TransactionEndpointResponse): diff --git a/chia/wallet/wallet_rpc_api.py b/chia/wallet/wallet_rpc_api.py index 727ab9c593b0..15bcf7ddf8ed 100644 --- a/chia/wallet/wallet_rpc_api.py +++ b/chia/wallet/wallet_rpc_api.py @@ -112,6 +112,10 @@ ApplySignatures, ApplySignaturesResponse, BalanceResponse, + CancelOffer, + CancelOfferResponse, + CancelOffers, + CancelOffersResponse, CATAssetIDToName, CATAssetIDToNameResponse, CATGetAssetID, @@ -2412,41 +2416,40 @@ async def get_offers_count(self, request: Empty) -> GetOffersCountResponse: ) @tx_endpoint(push=True) + @marshal async def cancel_offer( self, - request: dict[str, Any], + request: CancelOffer, action_scope: WalletActionScope, extra_conditions: tuple[Condition, ...] = tuple(), - ) -> EndpointResult: + ) -> CancelOfferResponse: wsm = self.service.wallet_state_manager - secure = request["secure"] - trade_id = bytes32.from_hexstr(request["trade_id"]) - fee: uint64 = uint64(request.get("fee", 0)) async with self.service.wallet_state_manager.lock: await wsm.trade_manager.cancel_pending_offers( - [trade_id], action_scope, fee=fee, secure=secure, extra_conditions=extra_conditions + [request.trade_id], + action_scope, + fee=request.fee, + secure=request.secure, + extra_conditions=extra_conditions, ) - return {"transactions": None} # tx_endpoint wrapper will take care of this + return CancelOfferResponse([], []) # tx_endpoint will fill in default values here @tx_endpoint(push=True, merge_spends=False) + @marshal async def cancel_offers( self, - request: dict[str, Any], + request: CancelOffers, action_scope: WalletActionScope, extra_conditions: tuple[Condition, ...] = tuple(), - ) -> EndpointResult: - secure = request["secure"] - batch_fee: uint64 = uint64(request.get("batch_fee", 0)) - batch_size = request.get("batch_size", 5) - cancel_all = request.get("cancel_all", False) - if cancel_all: - asset_id = None + ) -> CancelOffersResponse: + if request.cancel_all: + asset_id: Optional[str] = None else: - asset_id = request.get("asset_id", "xch") + asset_id = request.asset_id start: int = 0 - end: int = start + batch_size + end: int = start + request.batch_size trade_mgr = self.service.wallet_state_manager.trade_manager log.info(f"Start cancelling offers for {'asset_id: ' + asset_id if asset_id is not None else 'all'} ...") # Traverse offers page by page @@ -2464,7 +2467,7 @@ async def cancel_offers( include_completed=False, ) for trade in trades: - if cancel_all: + if request.cancel_all: records[trade.trade_id] = trade continue if trade.offer and trade.offer != b"": @@ -2480,20 +2483,20 @@ async def cancel_offers( await trade_mgr.cancel_pending_offers( list(records.keys()), action_scope, - batch_fee, - secure, + request.batch_fee, + request.secure, records, extra_conditions=extra_conditions, ) log.info(f"Cancelled offers {start} to {end} ...") # If fewer records were returned than requested, we're done - if len(trades) < batch_size: + if len(trades) < request.batch_size: break start = end - end += batch_size + end += request.batch_size - return {"transactions": None} # tx_endpoint wrapper will take care of this + return CancelOffersResponse([], []) # tx_endpoint wrapper will take care of this ########################################################################################## # Distributed Identities diff --git a/chia/wallet/wallet_rpc_client.py b/chia/wallet/wallet_rpc_client.py index 6fea02cd2079..4d10114d7c18 100644 --- a/chia/wallet/wallet_rpc_client.py +++ b/chia/wallet/wallet_rpc_client.py @@ -19,7 +19,9 @@ AddKeyResponse, ApplySignatures, ApplySignaturesResponse, + CancelOffer, CancelOfferResponse, + CancelOffers, CancelOffersResponse, CATAssetIDToName, CATAssetIDToNameResponse, @@ -715,58 +717,30 @@ async def get_offers_count(self) -> GetOffersCountResponse: async def cancel_offer( self, - trade_id: bytes32, + request: CancelOffer, tx_config: TXConfig, - fee: int = 0, - secure: bool = True, extra_conditions: tuple[Condition, ...] = tuple(), timelock_info: ConditionValidTimes = ConditionValidTimes(), - push: bool = True, ) -> CancelOfferResponse: - res = await self.fetch( - "cancel_offer", - { - "trade_id": trade_id.hex(), - "secure": secure, - "fee": fee, - "extra_conditions": conditions_to_json_dicts(extra_conditions), - "push": push, - **tx_config.to_json_dict(), - **timelock_info.to_json_dict(), - }, + return CancelOfferResponse.from_json_dict( + await self.fetch( + "cancel_offer", request.json_serialize_for_transport(tx_config, extra_conditions, timelock_info) + ) ) - return json_deserialize_with_clvm_streamable(res, CancelOfferResponse) - async def cancel_offers( self, + request: CancelOffers, tx_config: TXConfig, - batch_fee: int = 0, - secure: bool = True, - batch_size: int = 5, - cancel_all: bool = False, - asset_id: Optional[bytes32] = None, extra_conditions: tuple[Condition, ...] = tuple(), timelock_info: ConditionValidTimes = ConditionValidTimes(), - push: bool = True, ) -> CancelOffersResponse: - res = await self.fetch( - "cancel_offers", - { - "secure": secure, - "batch_fee": batch_fee, - "batch_size": batch_size, - "cancel_all": cancel_all, - "asset_id": None if asset_id is None else asset_id.hex(), - "extra_conditions": conditions_to_json_dicts(extra_conditions), - "push": push, - **tx_config.to_json_dict(), - **timelock_info.to_json_dict(), - }, + return CancelOffersResponse.from_json_dict( + await self.fetch( + "cancel_offers", request.json_serialize_for_transport(tx_config, extra_conditions, timelock_info) + ) ) - return json_deserialize_with_clvm_streamable(res, CancelOffersResponse) - async def get_cat_list(self) -> GetCATListResponse: return GetCATListResponse.from_json_dict(await self.fetch("get_cat_list", {}))