Skip to content

Commit daf6858

Browse files
committed
Port get_all_offers to @marshal
1 parent e95dba8 commit daf6858

File tree

6 files changed

+169
-110
lines changed

6 files changed

+169
-110
lines changed

chia/_tests/cmds/wallet/test_wallet.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
ExtendDerivationIndex,
5757
ExtendDerivationIndexResponse,
5858
FungibleAsset,
59+
GetAllOffers,
60+
GetAllOffersResponse,
5961
GetCurrentDerivationIndexResponse,
6062
GetHeightInfoResponse,
6163
GetNextAddress,
@@ -948,32 +950,22 @@ def test_get_offers(capsys: object, get_test_cli_clients: tuple[TestRpcClients,
948950

949951
# set RPC Client
950952
class GetOffersRpcClient(TestWalletRpcClient):
951-
async def get_all_offers(
952-
self,
953-
start: int = 0,
954-
end: int = 50,
955-
sort_key: Optional[str] = None,
956-
reverse: bool = False,
957-
file_contents: bool = False,
958-
exclude_my_offers: bool = False,
959-
exclude_taken_offers: bool = False,
960-
include_completed: bool = False,
961-
) -> list[TradeRecord]:
953+
async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse:
962954
self.add_to_log(
963955
"get_all_offers",
964956
(
965-
start,
966-
end,
967-
sort_key,
968-
reverse,
969-
file_contents,
970-
exclude_my_offers,
971-
exclude_taken_offers,
972-
include_completed,
957+
request.start,
958+
request.end,
959+
request.sort_key,
960+
request.reverse,
961+
request.file_contents,
962+
request.exclude_my_offers,
963+
request.exclude_taken_offers,
964+
request.include_completed,
973965
),
974966
)
975967
records: list[TradeRecord] = []
976-
for i in reversed(range(start, end - 1)): # reversed to match the sort order
968+
for i in reversed(range(request.start, request.end - 1)): # reversed to match the sort order
977969
trade_offer = TradeRecord(
978970
confirmed_at_index=uint32(0),
979971
accepted_at_time=None,
@@ -997,7 +989,7 @@ async def get_all_offers(
997989
),
998990
)
999991
records.append(trade_offer)
1000-
return records
992+
return GetAllOffersResponse([], records)
1001993

1002994
inst_rpc_client = GetOffersRpcClient()
1003995
test_rpc_clients.wallet_rpc_client = inst_rpc_client

chia/_tests/wallet/rpc/test_wallet_rpc.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
DIDTransferDID,
132132
DIDUpdateMetadata,
133133
FungibleAsset,
134+
GetAllOffers,
134135
GetCoinRecordsByNames,
135136
GetNextAddress,
136137
GetNotifications,
@@ -1595,7 +1596,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
15951596
CreateOfferForIDs(offer={str(1): "-5", cat_asset_id.hex(): "1"}, validate_only=True),
15961597
tx_config=wallet_environments.tx_config,
15971598
)
1598-
all_offers = await env_1.rpc_client.get_all_offers()
1599+
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
15991600
assert len(all_offers) == 0
16001601

16011602
driver_dict = {
@@ -1639,7 +1640,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
16391640
assert offer_validity_response.id == offer.name()
16401641
assert offer_validity_response.valid
16411642

1642-
all_offers = await env_1.rpc_client.get_all_offers(file_contents=True)
1643+
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers(file_contents=True))).trade_records
16431644
assert len(all_offers) == 1
16441645
assert TradeStatus(all_offers[0].status) == TradeStatus.PENDING_ACCEPT
16451646
assert all_offers[0].offer == bytes(offer)
@@ -1678,7 +1679,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
16781679
CreateOfferForIDs(offer={str(1): "-5", str(cat_wallet_id): "1"}, fee=uint64(1)),
16791680
tx_config=wallet_environments.tx_config,
16801681
)
1681-
all_offers = await env_1.rpc_client.get_all_offers()
1682+
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
16821683
assert len(all_offers) == 2
16831684
offer_count = await env_1.rpc_client.get_offers_count()
16841685
assert offer_count.total == 2
@@ -1768,37 +1769,61 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
17681769
return [t.trade_id for t in trades]
17691770

17701771
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record
1771-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True) # confirmed at index descending
1772+
all_offers = ( # confirmed at index descending
1773+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True))
1774+
).trade_records
17721775
assert len(all_offers) == 2
17731776
assert only_ids(all_offers) == only_ids([trade_record, new_trade_record])
1774-
all_offers = await env_1.rpc_client.get_all_offers(
1775-
include_completed=True, reverse=True
1776-
) # confirmed at index ascending
1777+
all_offers = ( # confirmed at index ascending
1778+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, reverse=True))
1779+
).trade_records
17771780
assert only_ids(all_offers) == only_ids([new_trade_record, trade_record])
1778-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, sort_key="RELEVANCE") # most relevant
1781+
all_offers = ( # most relevant
1782+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE"))
1783+
).trade_records
17791784
assert only_ids(all_offers) == only_ids([new_trade_record, trade_record])
1780-
all_offers = await env_1.rpc_client.get_all_offers(
1781-
include_completed=True, sort_key="RELEVANCE", reverse=True
1782-
) # least relevant
1785+
all_offers = ( # least relevant
1786+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE", reverse=True))
1787+
).trade_records
17831788
assert only_ids(all_offers) == only_ids([trade_record, new_trade_record])
17841789
# Test pagination
1785-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=1)
1790+
all_offers = (
1791+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(1)))
1792+
).trade_records
17861793
assert len(all_offers) == 1
1787-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=50)
1794+
all_offers = (
1795+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(10)))
1796+
).trade_records
17881797
assert len(all_offers) == 0
1789-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=50)
1798+
all_offers = (
1799+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(50)))
1800+
).trade_records
17901801
assert len(all_offers) == 2
17911802

17921803
await env_1.rpc_client.create_offer_for_ids(
17931804
CreateOfferForIDs(offer={str(1): "-5", cat_asset_id.hex(): "1"}, driver_dict=driver_dict),
17941805
tx_config=wallet_environments.tx_config,
17951806
)
17961807
assert (
1797-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2
1808+
len(
1809+
[
1810+
o
1811+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1812+
if o.status == TradeStatus.PENDING_ACCEPT.value
1813+
]
1814+
)
1815+
== 2
17981816
)
17991817
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, batch_size=1)
18001818
assert (
1801-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
1819+
len(
1820+
[
1821+
o
1822+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1823+
if o.status == TradeStatus.PENDING_ACCEPT.value
1824+
]
1825+
)
1826+
== 0
18021827
)
18031828
await wallet_environments.process_pending_states(
18041829
[
@@ -1835,11 +1860,25 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
18351860
tx_config=wallet_environments.tx_config,
18361861
)
18371862
assert (
1838-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2
1863+
len(
1864+
[
1865+
o
1866+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1867+
if o.status == TradeStatus.PENDING_ACCEPT.value
1868+
]
1869+
)
1870+
== 2
18391871
)
18401872
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, cancel_all=True)
18411873
assert (
1842-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
1874+
len(
1875+
[
1876+
o
1877+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1878+
if o.status == TradeStatus.PENDING_ACCEPT.value
1879+
]
1880+
)
1881+
== 0
18431882
)
18441883

18451884
await wallet_environments.process_pending_states(
@@ -1883,15 +1922,36 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
18831922
tx_config=wallet_environments.tx_config,
18841923
)
18851924
assert (
1886-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1
1925+
len(
1926+
[
1927+
o
1928+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1929+
if o.status == TradeStatus.PENDING_ACCEPT.value
1930+
]
1931+
)
1932+
== 1
18871933
)
18881934
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=bytes32.zeros)
18891935
assert (
1890-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1
1936+
len(
1937+
[
1938+
o
1939+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1940+
if o.status == TradeStatus.PENDING_ACCEPT.value
1941+
]
1942+
)
1943+
== 1
18911944
)
18921945
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=cat_asset_id)
18931946
assert (
1894-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
1947+
len(
1948+
[
1949+
o
1950+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1951+
if o.status == TradeStatus.PENDING_ACCEPT.value
1952+
]
1953+
)
1954+
== 0
18951955
)
18961956

18971957
with pytest.raises(ValueError, match="not currently supported"):

chia/cmds/wallet_funcs.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
DIDUpdateMetadata,
6565
ExtendDerivationIndex,
6666
FungibleAsset,
67+
GetAllOffers,
6768
GetNextAddress,
6869
GetNotifications,
6970
GetOffer,
@@ -769,16 +770,20 @@ async def get_offers(
769770

770771
# Traverse offers page by page
771772
while True:
772-
new_records: list[TradeRecord] = await wallet_client.get_all_offers(
773-
start,
774-
end,
775-
sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT",
776-
reverse=reverse,
777-
file_contents=file_contents,
778-
exclude_my_offers=exclude_my_offers,
779-
exclude_taken_offers=exclude_taken_offers,
780-
include_completed=include_completed,
781-
)
773+
new_records: list[TradeRecord] = (
774+
await wallet_client.get_all_offers(
775+
GetAllOffers(
776+
start=uint16(start),
777+
end=uint16(end),
778+
sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT",
779+
reverse=reverse,
780+
file_contents=file_contents,
781+
exclude_my_offers=exclude_my_offers,
782+
exclude_taken_offers=exclude_taken_offers,
783+
include_completed=include_completed,
784+
)
785+
)
786+
).trade_records
782787
records.extend(new_records)
783788

784789
# If fewer records were returned than requested, we're done

chia/wallet/wallet_request_types.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,6 +1974,42 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> Self:
19741974
)
19751975

19761976

1977+
@streamable
1978+
@dataclass(frozen=True)
1979+
class GetAllOffers(Streamable):
1980+
start: uint16 = uint16(0)
1981+
end: uint16 = uint16(10)
1982+
exclude_my_offers: bool = False
1983+
exclude_taken_offers: bool = False
1984+
include_completed: bool = False
1985+
sort_key: Optional[str] = None
1986+
reverse: bool = False
1987+
file_contents: bool = False
1988+
1989+
1990+
@streamable
1991+
@dataclass(frozen=True)
1992+
class GetAllOffersResponse(Streamable):
1993+
offers: Optional[list[str]]
1994+
trade_records: list[TradeRecord]
1995+
1996+
def to_json_dict(self) -> dict[str, Any]:
1997+
return {**super().to_json_dict(), "trade_records": [tr.to_json_dict_convenience() for tr in self.trade_records]}
1998+
1999+
@classmethod
2000+
def from_json_dict(cls, json_dict: dict[str, Any]) -> Self:
2001+
return cls(
2002+
offers=json_dict["offers"],
2003+
trade_records=[
2004+
TradeRecord.from_json_dict_convenience(
2005+
json_tr,
2006+
bytes(Offer.from_bech32(json_dict["offers"][i])).hex() if json_dict["offers"] is not None else "",
2007+
)
2008+
for i, json_tr in enumerate(json_dict["trade_records"])
2009+
],
2010+
)
2011+
2012+
19772013
@streamable
19782014
@dataclass(frozen=True)
19792015
class CancelOfferResponse(TransactionEndpointResponse):

chia/wallet/wallet_rpc_api.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@
186186
GatherSigningInfo,
187187
GatherSigningInfoResponse,
188188
GenerateMnemonicResponse,
189+
GetAllOffers,
190+
GetAllOffersResponse,
189191
GetCATListResponse,
190192
GetCoinRecordsByNames,
191193
GetCoinRecordsByNamesResponse,
@@ -2370,36 +2372,32 @@ async def get_offer(self, request: GetOffer) -> GetOfferResponse:
23702372
trade_record,
23712373
)
23722374

2373-
async def get_all_offers(self, request: dict[str, Any]) -> EndpointResult:
2375+
@marshal
2376+
async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse:
23742377
trade_mgr = self.service.wallet_state_manager.trade_manager
23752378

2376-
start: int = request.get("start", 0)
2377-
end: int = request.get("end", 10)
2378-
exclude_my_offers: bool = request.get("exclude_my_offers", False)
2379-
exclude_taken_offers: bool = request.get("exclude_taken_offers", False)
2380-
include_completed: bool = request.get("include_completed", False)
2381-
sort_key: Optional[str] = request.get("sort_key", None)
2382-
reverse: bool = request.get("reverse", False)
2383-
file_contents: bool = request.get("file_contents", False)
2384-
23852379
all_trades = await trade_mgr.trade_store.get_trades_between(
2386-
start,
2387-
end,
2388-
sort_key=sort_key,
2389-
reverse=reverse,
2390-
exclude_my_offers=exclude_my_offers,
2391-
exclude_taken_offers=exclude_taken_offers,
2392-
include_completed=include_completed,
2380+
request.start,
2381+
request.end,
2382+
sort_key=request.sort_key,
2383+
reverse=request.reverse,
2384+
exclude_my_offers=request.exclude_my_offers,
2385+
exclude_taken_offers=request.exclude_taken_offers,
2386+
include_completed=request.include_completed,
23932387
)
23942388
result = []
2395-
offer_values: Optional[list[str]] = [] if file_contents else None
2389+
offer_values: Optional[list[str]] = [] if request.file_contents else None
23962390
for trade in all_trades:
2397-
result.append(trade.to_json_dict_convenience())
2398-
if file_contents and offer_values is not None:
2391+
result.append(trade)
2392+
if request.file_contents:
23992393
offer_to_return: bytes = trade.offer if trade.taken_offer is None else trade.taken_offer
2400-
offer_values.append(Offer.from_bytes(offer_to_return).to_bech32())
2394+
# semantics guarantee this to be not None
2395+
offer_values.append(Offer.from_bytes(offer_to_return).to_bech32()) # type: ignore[union-attr]
24012396

2402-
return {"trade_records": result, "offers": offer_values}
2397+
return GetAllOffersResponse(
2398+
trade_records=result,
2399+
offers=offer_values,
2400+
)
24032401

24042402
@marshal
24052403
async def get_offers_count(self, request: Empty) -> GetOffersCountResponse:

0 commit comments

Comments
 (0)