Skip to content

Commit 825ad74

Browse files
committed
Port get_all_offers to @marshal
1 parent 19c9c00 commit 825ad74

File tree

6 files changed

+169
-110
lines changed

6 files changed

+169
-110
lines changed

chia/_tests/cmds/wallet/test_wallet.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
ExtendDerivationIndex,
5757
ExtendDerivationIndexResponse,
5858
FungibleAsset,
59+
GetAllOffers,
60+
GetAllOffersResponse,
5961
GetCurrentDerivationIndexResponse,
6062
GetHeightInfoResponse,
6163
GetNextAddress,
@@ -948,32 +950,22 @@ def test_get_offers(capsys: object, get_test_cli_clients: tuple[TestRpcClients,
948950

949951
# set RPC Client
950952
class GetOffersRpcClient(TestWalletRpcClient):
951-
async def get_all_offers(
952-
self,
953-
start: int = 0,
954-
end: int = 50,
955-
sort_key: Optional[str] = None,
956-
reverse: bool = False,
957-
file_contents: bool = False,
958-
exclude_my_offers: bool = False,
959-
exclude_taken_offers: bool = False,
960-
include_completed: bool = False,
961-
) -> list[TradeRecord]:
953+
async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse:
962954
self.add_to_log(
963955
"get_all_offers",
964956
(
965-
start,
966-
end,
967-
sort_key,
968-
reverse,
969-
file_contents,
970-
exclude_my_offers,
971-
exclude_taken_offers,
972-
include_completed,
957+
request.start,
958+
request.end,
959+
request.sort_key,
960+
request.reverse,
961+
request.file_contents,
962+
request.exclude_my_offers,
963+
request.exclude_taken_offers,
964+
request.include_completed,
973965
),
974966
)
975967
records: list[TradeRecord] = []
976-
for i in reversed(range(start, end - 1)): # reversed to match the sort order
968+
for i in reversed(range(request.start, request.end - 1)): # reversed to match the sort order
977969
trade_offer = TradeRecord(
978970
confirmed_at_index=uint32(0),
979971
accepted_at_time=None,
@@ -997,7 +989,7 @@ async def get_all_offers(
997989
),
998990
)
999991
records.append(trade_offer)
1000-
return records
992+
return GetAllOffersResponse([], records)
1001993

1002994
inst_rpc_client = GetOffersRpcClient()
1003995
test_rpc_clients.wallet_rpc_client = inst_rpc_client

chia/_tests/wallet/rpc/test_wallet_rpc.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
DIDTransferDID,
131131
DIDUpdateMetadata,
132132
FungibleAsset,
133+
GetAllOffers,
133134
GetCoinRecordsByNames,
134135
GetNextAddress,
135136
GetNotifications,
@@ -1555,7 +1556,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
15551556
CreateOfferForIDs(offer={str(1): "-5", cat_asset_id.hex(): "1"}, validate_only=True),
15561557
tx_config=wallet_environments.tx_config,
15571558
)
1558-
all_offers = await env_1.rpc_client.get_all_offers()
1559+
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
15591560
assert len(all_offers) == 0
15601561

15611562
driver_dict = {
@@ -1599,7 +1600,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
15991600
assert offer_validity_response.id == offer.name()
16001601
assert offer_validity_response.valid
16011602

1602-
all_offers = await env_1.rpc_client.get_all_offers(file_contents=True)
1603+
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers(file_contents=True))).trade_records
16031604
assert len(all_offers) == 1
16041605
assert TradeStatus(all_offers[0].status) == TradeStatus.PENDING_ACCEPT
16051606
assert all_offers[0].offer == bytes(offer)
@@ -1638,7 +1639,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
16381639
CreateOfferForIDs(offer={str(1): "-5", str(cat_wallet_id): "1"}, fee=uint64(1)),
16391640
tx_config=wallet_environments.tx_config,
16401641
)
1641-
all_offers = await env_1.rpc_client.get_all_offers()
1642+
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
16421643
assert len(all_offers) == 2
16431644
offer_count = await env_1.rpc_client.get_offers_count()
16441645
assert offer_count.total == 2
@@ -1728,37 +1729,61 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
17281729
return [t.trade_id for t in trades]
17291730

17301731
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record
1731-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True) # confirmed at index descending
1732+
all_offers = ( # confirmed at index descending
1733+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True))
1734+
).trade_records
17321735
assert len(all_offers) == 2
17331736
assert only_ids(all_offers) == only_ids([trade_record, new_trade_record])
1734-
all_offers = await env_1.rpc_client.get_all_offers(
1735-
include_completed=True, reverse=True
1736-
) # confirmed at index ascending
1737+
all_offers = ( # confirmed at index ascending
1738+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, reverse=True))
1739+
).trade_records
17371740
assert only_ids(all_offers) == only_ids([new_trade_record, trade_record])
1738-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, sort_key="RELEVANCE") # most relevant
1741+
all_offers = ( # most relevant
1742+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE"))
1743+
).trade_records
17391744
assert only_ids(all_offers) == only_ids([new_trade_record, trade_record])
1740-
all_offers = await env_1.rpc_client.get_all_offers(
1741-
include_completed=True, sort_key="RELEVANCE", reverse=True
1742-
) # least relevant
1745+
all_offers = ( # least relevant
1746+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE", reverse=True))
1747+
).trade_records
17431748
assert only_ids(all_offers) == only_ids([trade_record, new_trade_record])
17441749
# Test pagination
1745-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=1)
1750+
all_offers = (
1751+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(1)))
1752+
).trade_records
17461753
assert len(all_offers) == 1
1747-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=50)
1754+
all_offers = (
1755+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(10)))
1756+
).trade_records
17481757
assert len(all_offers) == 0
1749-
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=50)
1758+
all_offers = (
1759+
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(50)))
1760+
).trade_records
17501761
assert len(all_offers) == 2
17511762

17521763
await env_1.rpc_client.create_offer_for_ids(
17531764
CreateOfferForIDs(offer={str(1): "-5", cat_asset_id.hex(): "1"}, driver_dict=driver_dict),
17541765
tx_config=wallet_environments.tx_config,
17551766
)
17561767
assert (
1757-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2
1768+
len(
1769+
[
1770+
o
1771+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1772+
if o.status == TradeStatus.PENDING_ACCEPT.value
1773+
]
1774+
)
1775+
== 2
17581776
)
17591777
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, batch_size=1)
17601778
assert (
1761-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
1779+
len(
1780+
[
1781+
o
1782+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1783+
if o.status == TradeStatus.PENDING_ACCEPT.value
1784+
]
1785+
)
1786+
== 0
17621787
)
17631788
await wallet_environments.process_pending_states(
17641789
[
@@ -1795,11 +1820,25 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
17951820
tx_config=wallet_environments.tx_config,
17961821
)
17971822
assert (
1798-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2
1823+
len(
1824+
[
1825+
o
1826+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1827+
if o.status == TradeStatus.PENDING_ACCEPT.value
1828+
]
1829+
)
1830+
== 2
17991831
)
18001832
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, cancel_all=True)
18011833
assert (
1802-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
1834+
len(
1835+
[
1836+
o
1837+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1838+
if o.status == TradeStatus.PENDING_ACCEPT.value
1839+
]
1840+
)
1841+
== 0
18031842
)
18041843

18051844
await wallet_environments.process_pending_states(
@@ -1843,15 +1882,36 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
18431882
tx_config=wallet_environments.tx_config,
18441883
)
18451884
assert (
1846-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1
1885+
len(
1886+
[
1887+
o
1888+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1889+
if o.status == TradeStatus.PENDING_ACCEPT.value
1890+
]
1891+
)
1892+
== 1
18471893
)
18481894
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=bytes32.zeros)
18491895
assert (
1850-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1
1896+
len(
1897+
[
1898+
o
1899+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1900+
if o.status == TradeStatus.PENDING_ACCEPT.value
1901+
]
1902+
)
1903+
== 1
18511904
)
18521905
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=cat_asset_id)
18531906
assert (
1854-
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
1907+
len(
1908+
[
1909+
o
1910+
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
1911+
if o.status == TradeStatus.PENDING_ACCEPT.value
1912+
]
1913+
)
1914+
== 0
18551915
)
18561916

18571917
with pytest.raises(ValueError, match="not currently supported"):

chia/cmds/wallet_funcs.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
DIDUpdateMetadata,
6565
ExtendDerivationIndex,
6666
FungibleAsset,
67+
GetAllOffers,
6768
GetNextAddress,
6869
GetNotifications,
6970
GetOffer,
@@ -769,16 +770,20 @@ async def get_offers(
769770

770771
# Traverse offers page by page
771772
while True:
772-
new_records: list[TradeRecord] = await wallet_client.get_all_offers(
773-
start,
774-
end,
775-
sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT",
776-
reverse=reverse,
777-
file_contents=file_contents,
778-
exclude_my_offers=exclude_my_offers,
779-
exclude_taken_offers=exclude_taken_offers,
780-
include_completed=include_completed,
781-
)
773+
new_records: list[TradeRecord] = (
774+
await wallet_client.get_all_offers(
775+
GetAllOffers(
776+
start=uint16(start),
777+
end=uint16(end),
778+
sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT",
779+
reverse=reverse,
780+
file_contents=file_contents,
781+
exclude_my_offers=exclude_my_offers,
782+
exclude_taken_offers=exclude_taken_offers,
783+
include_completed=include_completed,
784+
)
785+
)
786+
).trade_records
782787
records.extend(new_records)
783788

784789
# If fewer records were returned than requested, we're done

chia/wallet/wallet_request_types.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,42 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> Self:
19821982
)
19831983

19841984

1985+
@streamable
1986+
@dataclass(frozen=True)
1987+
class GetAllOffers(Streamable):
1988+
start: uint16 = uint16(0)
1989+
end: uint16 = uint16(10)
1990+
exclude_my_offers: bool = False
1991+
exclude_taken_offers: bool = False
1992+
include_completed: bool = False
1993+
sort_key: Optional[str] = None
1994+
reverse: bool = False
1995+
file_contents: bool = False
1996+
1997+
1998+
@streamable
1999+
@dataclass(frozen=True)
2000+
class GetAllOffersResponse(Streamable):
2001+
offers: Optional[list[str]]
2002+
trade_records: list[TradeRecord]
2003+
2004+
def to_json_dict(self) -> dict[str, Any]:
2005+
return {**super().to_json_dict(), "trade_records": [tr.to_json_dict_convenience() for tr in self.trade_records]}
2006+
2007+
@classmethod
2008+
def from_json_dict(cls, json_dict: dict[str, Any]) -> Self:
2009+
return cls(
2010+
offers=json_dict["offers"],
2011+
trade_records=[
2012+
TradeRecord.from_json_dict_convenience(
2013+
json_tr,
2014+
bytes(Offer.from_bech32(json_dict["offers"][i])).hex() if json_dict["offers"] is not None else "",
2015+
)
2016+
for i, json_tr in enumerate(json_dict["trade_records"])
2017+
],
2018+
)
2019+
2020+
19852021
@streamable
19862022
@dataclass(frozen=True)
19872023
class CancelOfferResponse(TransactionEndpointResponse):

chia/wallet/wallet_rpc_api.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@
187187
GatherSigningInfo,
188188
GatherSigningInfoResponse,
189189
GenerateMnemonicResponse,
190+
GetAllOffers,
191+
GetAllOffersResponse,
190192
GetCATListResponse,
191193
GetCoinRecordsByNames,
192194
GetCoinRecordsByNamesResponse,
@@ -2369,36 +2371,32 @@ async def get_offer(self, request: GetOffer) -> GetOfferResponse:
23692371
trade_record,
23702372
)
23712373

2372-
async def get_all_offers(self, request: dict[str, Any]) -> EndpointResult:
2374+
@marshal
2375+
async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse:
23732376
trade_mgr = self.service.wallet_state_manager.trade_manager
23742377

2375-
start: int = request.get("start", 0)
2376-
end: int = request.get("end", 10)
2377-
exclude_my_offers: bool = request.get("exclude_my_offers", False)
2378-
exclude_taken_offers: bool = request.get("exclude_taken_offers", False)
2379-
include_completed: bool = request.get("include_completed", False)
2380-
sort_key: Optional[str] = request.get("sort_key", None)
2381-
reverse: bool = request.get("reverse", False)
2382-
file_contents: bool = request.get("file_contents", False)
2383-
23842378
all_trades = await trade_mgr.trade_store.get_trades_between(
2385-
start,
2386-
end,
2387-
sort_key=sort_key,
2388-
reverse=reverse,
2389-
exclude_my_offers=exclude_my_offers,
2390-
exclude_taken_offers=exclude_taken_offers,
2391-
include_completed=include_completed,
2379+
request.start,
2380+
request.end,
2381+
sort_key=request.sort_key,
2382+
reverse=request.reverse,
2383+
exclude_my_offers=request.exclude_my_offers,
2384+
exclude_taken_offers=request.exclude_taken_offers,
2385+
include_completed=request.include_completed,
23922386
)
23932387
result = []
2394-
offer_values: Optional[list[str]] = [] if file_contents else None
2388+
offer_values: Optional[list[str]] = [] if request.file_contents else None
23952389
for trade in all_trades:
2396-
result.append(trade.to_json_dict_convenience())
2397-
if file_contents and offer_values is not None:
2390+
result.append(trade)
2391+
if request.file_contents:
23982392
offer_to_return: bytes = trade.offer if trade.taken_offer is None else trade.taken_offer
2399-
offer_values.append(Offer.from_bytes(offer_to_return).to_bech32())
2393+
# semantics guarantee this to be not None
2394+
offer_values.append(Offer.from_bytes(offer_to_return).to_bech32()) # type: ignore[union-attr]
24002395

2401-
return {"trade_records": result, "offers": offer_values}
2396+
return GetAllOffersResponse(
2397+
trade_records=result,
2398+
offers=offer_values,
2399+
)
24022400

24032401
@marshal
24042402
async def get_offers_count(self, request: Empty) -> GetOffersCountResponse:

0 commit comments

Comments
 (0)