Skip to content

Commit 53cebb1

Browse files
committed
Port get_offer to @marshal
1 parent 9c3ef2f commit 53cebb1

File tree

7 files changed

+74
-33
lines changed

7 files changed

+74
-33
lines changed

chia/_tests/cmds/wallet/test_wallet.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
GetHeightInfoResponse,
6161
GetNextAddress,
6262
GetNextAddressResponse,
63+
GetOffer,
64+
GetOfferResponse,
6365
GetTransaction,
6466
GetTransactions,
6567
GetTransactionsResponse,
@@ -1163,22 +1165,25 @@ def test_cancel_offer(capsys: object, get_test_cli_clients: tuple[TestRpcClients
11631165

11641166
# set RPC Client
11651167
class CancelOfferRpcClient(TestWalletRpcClient):
1166-
async def get_offer(self, trade_id: bytes32, file_contents: bool = False) -> TradeRecord:
1167-
self.add_to_log("get_offer", (trade_id, file_contents))
1168+
async def get_offer(self, request: GetOffer) -> GetOfferResponse:
1169+
self.add_to_log("get_offer", (request.trade_id, request.file_contents))
11681170
offer = Offer.from_bech32(test_offer_file_bech32)
1169-
return TradeRecord(
1170-
confirmed_at_index=uint32(0),
1171-
accepted_at_time=uint64(0),
1172-
created_at_time=uint64(12345678),
1173-
is_my_offer=True,
1174-
sent=uint32(0),
1175-
sent_to=[],
1176-
offer=bytes(offer),
1177-
taken_offer=None,
1178-
coins_of_interest=offer.get_involved_coins(),
1179-
trade_id=offer.name(),
1180-
status=uint32(TradeStatus.PENDING_ACCEPT.value),
1181-
valid_times=ConditionValidTimes(),
1171+
return GetOfferResponse(
1172+
test_offer_file_bech32,
1173+
TradeRecord(
1174+
confirmed_at_index=uint32(0),
1175+
accepted_at_time=uint64(0),
1176+
created_at_time=uint64(12345678),
1177+
is_my_offer=True,
1178+
sent=uint32(0),
1179+
sent_to=[],
1180+
offer=bytes(offer),
1181+
taken_offer=None,
1182+
coins_of_interest=offer.get_involved_coins(),
1183+
trade_id=offer.name(),
1184+
status=uint32(TradeStatus.PENDING_ACCEPT.value),
1185+
valid_times=ConditionValidTimes(),
1186+
),
11821187
)
11831188

11841189
async def cancel_offer(

chia/_tests/wallet/rpc/test_wallet_rpc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
GetCoinRecordsByNames,
134134
GetNextAddress,
135135
GetNotifications,
136+
GetOffer,
136137
GetOfferSummary,
137138
GetPrivateKey,
138139
GetSpendableCoins,
@@ -1622,15 +1623,15 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
16221623

16231624
await env_1.rpc_client.cancel_offer(offer.name(), wallet_environments.tx_config, secure=False)
16241625

1625-
trade_record = await env_1.rpc_client.get_offer(offer.name(), file_contents=True)
1626+
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name(), file_contents=True))).trade_record
16261627
assert trade_record.offer == bytes(offer)
16271628
assert TradeStatus(trade_record.status) == TradeStatus.CANCELLED
16281629

16291630
failed_cancel_res = await env_1.rpc_client.cancel_offer(
16301631
offer.name(), wallet_environments.tx_config, fee=uint64(1), secure=True
16311632
)
16321633

1633-
trade_record = await env_1.rpc_client.get_offer(offer.name())
1634+
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record
16341635
assert TradeStatus(trade_record.status) == TradeStatus.PENDING_CANCEL
16351636

16361637
create_res = await env_1.rpc_client.create_offer_for_ids(
@@ -1717,7 +1718,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
17171718
)
17181719

17191720
async def is_trade_confirmed(client: WalletRpcClient, offer: Offer) -> bool:
1720-
trade_record = await client.get_offer(offer.name())
1721+
trade_record = (await client.get_offer(GetOffer(offer.name()))).trade_record
17211722
return TradeStatus(trade_record.status) == TradeStatus.CONFIRMED
17221723

17231724
await time_out_assert(15, is_trade_confirmed, True, env_1.rpc_client, offer)
@@ -1726,7 +1727,7 @@ async def is_trade_confirmed(client: WalletRpcClient, offer: Offer) -> bool:
17261727
def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
17271728
return [t.trade_id for t in trades]
17281729

1729-
trade_record = await env_1.rpc_client.get_offer(offer.name())
1730+
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record
17301731
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True) # confirmed at index descending
17311732
assert len(all_offers) == 2
17321733
assert only_ids(all_offers) == only_ids([trade_record, new_trade_record])

chia/cmds/wallet_funcs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
FungibleAsset,
6767
GetNextAddress,
6868
GetNotifications,
69+
GetOffer,
6970
GetTransaction,
7071
GetTransactions,
7172
GetWalletBalance,
@@ -787,7 +788,7 @@ async def get_offers(
787788
start = end
788789
end += batch_size
789790
else:
790-
records = [await wallet_client.get_offer(offer_id, file_contents)]
791+
records = [(await wallet_client.get_offer(GetOffer(offer_id, file_contents))).trade_record]
791792
if filepath is not None:
792793
with open(pathlib.Path(filepath), "w") as file:
793794
file.write(Offer.from_bytes(records[0].offer).to_bech32())
@@ -919,7 +920,7 @@ async def cancel_offer(
919920
condition_valid_times: ConditionValidTimes,
920921
) -> list[TransactionRecord]:
921922
async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, fingerprint, config):
922-
trade_record = await wallet_client.get_offer(offer_id, file_contents=True)
923+
trade_record = (await wallet_client.get_offer(GetOffer(offer_id, file_contents=True))).trade_record
923924
await print_trade_record(trade_record, wallet_client, summaries=True)
924925

925926
cli_confirm(f"Are you sure you wish to cancel offer with ID: {trade_record.trade_id}? (y/n): ")

chia/data_layer/data_layer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
DLUpdateMultiple,
8282
DLUpdateMultipleUpdates,
8383
DLUpdateRoot,
84+
GetOffer,
8485
LauncherRootPair,
8586
LogIn,
8687
TakeOffer,
@@ -1274,7 +1275,9 @@ async def cancel_offer(self, trade_id: bytes32, secure: bool, fee: uint64) -> No
12741275
store_ids: list[bytes32] = []
12751276

12761277
if not secure:
1277-
trade_record = await self.wallet_rpc.get_offer(trade_id=trade_id, file_contents=True)
1278+
trade_record = (
1279+
await self.wallet_rpc.get_offer(GetOffer(trade_id=trade_id, file_contents=True))
1280+
).trade_record
12781281
trading_offer = TradingOffer.from_bytes(trade_record.offer)
12791282
summary = await DataLayerWallet.get_offer_summary(offer=trading_offer)
12801283
store_ids = [offered.launcher_id for offered in summary.offered]

chia/wallet/wallet_request_types.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,6 +1955,33 @@ class TakeOfferResponse(_OfferEndpointResponse): # Inheriting for de-dup sake
19551955
pass
19561956

19571957

1958+
@streamable
1959+
@dataclass(frozen=True)
1960+
class GetOffer(Streamable):
1961+
trade_id: bytes32
1962+
file_contents: bool = False
1963+
1964+
1965+
@streamable
1966+
@dataclass(frozen=True)
1967+
class GetOfferResponse(Streamable):
1968+
offer: Optional[str]
1969+
trade_record: TradeRecord
1970+
1971+
def to_json_dict(self) -> dict[str, Any]:
1972+
return {**super().to_json_dict(), "trade_record": self.trade_record.to_json_dict_convenience()}
1973+
1974+
@classmethod
1975+
def from_json_dict(cls, json_dict: dict[str, Any]) -> Self:
1976+
return cls(
1977+
offer=json_dict["offer"],
1978+
trade_record=TradeRecord.from_json_dict_convenience(
1979+
json_dict["trade_record"],
1980+
bytes(Offer.from_bech32(json_dict["offer"])).hex() if json_dict["offer"] is not None else "",
1981+
),
1982+
)
1983+
1984+
19581985
@streamable
19591986
@dataclass(frozen=True)
19601987
class CancelOfferResponse(TransactionEndpointResponse):

chia/wallet/wallet_rpc_api.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@
197197
GetNextAddressResponse,
198198
GetNotifications,
199199
GetNotificationsResponse,
200+
GetOffer,
201+
GetOfferResponse,
200202
GetOffersCountResponse,
201203
GetOfferSummary,
202204
GetOfferSummaryResponse,
@@ -2355,18 +2357,20 @@ async def take_offer(
23552357
trade_record,
23562358
)
23572359

2358-
async def get_offer(self, request: dict[str, Any]) -> EndpointResult:
2360+
@marshal
2361+
async def get_offer(self, request: GetOffer) -> GetOfferResponse:
23592362
trade_mgr = self.service.wallet_state_manager.trade_manager
23602363

2361-
trade_id = bytes32.from_hexstr(request["trade_id"])
2362-
file_contents: bool = request.get("file_contents", False)
2363-
trade_record: Optional[TradeRecord] = await trade_mgr.get_trade_by_id(bytes32(trade_id))
2364+
trade_record: Optional[TradeRecord] = await trade_mgr.get_trade_by_id(request.trade_id)
23642365
if trade_record is None:
2365-
raise ValueError(f"No trade with trade id: {trade_id.hex()}")
2366+
raise ValueError(f"No trade with trade id: {request.trade_id.hex()}")
23662367

23672368
offer_to_return: bytes = trade_record.offer if trade_record.taken_offer is None else trade_record.taken_offer
2368-
offer_value: Optional[str] = Offer.from_bytes(offer_to_return).to_bech32() if file_contents else None
2369-
return {"trade_record": trade_record.to_json_dict_convenience(), "offer": offer_value}
2369+
offer: Optional[str] = Offer.from_bytes(offer_to_return).to_bech32() if request.file_contents else None
2370+
return GetOfferResponse(
2371+
offer,
2372+
trade_record,
2373+
)
23702374

23712375
async def get_all_offers(self, request: dict[str, Any]) -> EndpointResult:
23722376
trade_mgr = self.service.wallet_state_manager.trade_manager

chia/wallet/wallet_rpc_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
GetNextAddressResponse,
108108
GetNotifications,
109109
GetNotificationsResponse,
110+
GetOffer,
111+
GetOfferResponse,
110112
GetOffersCountResponse,
111113
GetOfferSummary,
112114
GetOfferSummaryResponse,
@@ -702,10 +704,8 @@ async def take_offer(
702704
)
703705
)
704706

705-
async def get_offer(self, trade_id: bytes32, file_contents: bool = False) -> TradeRecord:
706-
res = await self.fetch("get_offer", {"trade_id": trade_id.hex(), "file_contents": file_contents})
707-
offer_str = bytes(Offer.from_bech32(res["offer"])).hex() if file_contents else ""
708-
return TradeRecord.from_json_dict_convenience(res["trade_record"], offer_str)
707+
async def get_offer(self, request: GetOffer) -> GetOfferResponse:
708+
return GetOfferResponse.from_json_dict(await self.fetch("get_offer", request.to_json_dict()))
709709

710710
async def get_all_offers(
711711
self,

0 commit comments

Comments
 (0)