Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 33 additions & 36 deletions chia/_tests/cmds/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,14 @@
ExtendDerivationIndex,
ExtendDerivationIndexResponse,
FungibleAsset,
GetAllOffers,
GetAllOffersResponse,
GetCurrentDerivationIndexResponse,
GetHeightInfoResponse,
GetNextAddress,
GetNextAddressResponse,
GetOffer,
GetOfferResponse,
GetTransaction,
GetTransactions,
GetTransactionsResponse,
Expand Down Expand Up @@ -946,32 +950,22 @@ def test_get_offers(capsys: object, get_test_cli_clients: tuple[TestRpcClients,

# set RPC Client
class GetOffersRpcClient(TestWalletRpcClient):
async def get_all_offers(
self,
start: int = 0,
end: int = 50,
sort_key: Optional[str] = None,
reverse: bool = False,
file_contents: bool = False,
exclude_my_offers: bool = False,
exclude_taken_offers: bool = False,
include_completed: bool = False,
) -> list[TradeRecord]:
async def get_all_offers(self, request: GetAllOffers) -> GetAllOffersResponse:
self.add_to_log(
"get_all_offers",
(
start,
end,
sort_key,
reverse,
file_contents,
exclude_my_offers,
exclude_taken_offers,
include_completed,
request.start,
request.end,
request.sort_key,
request.reverse,
request.file_contents,
request.exclude_my_offers,
request.exclude_taken_offers,
request.include_completed,
),
)
records: list[TradeRecord] = []
for i in reversed(range(start, end - 1)): # reversed to match the sort order
for i in reversed(range(request.start, request.end - 1)): # reversed to match the sort order
trade_offer = TradeRecord(
confirmed_at_index=uint32(0),
accepted_at_time=None,
Expand All @@ -995,7 +989,7 @@ async def get_all_offers(
),
)
records.append(trade_offer)
return records
return GetAllOffersResponse([], records)

inst_rpc_client = GetOffersRpcClient()
test_rpc_clients.wallet_rpc_client = inst_rpc_client
Expand Down Expand Up @@ -1163,22 +1157,25 @@ def test_cancel_offer(capsys: object, get_test_cli_clients: tuple[TestRpcClients

# set RPC Client
class CancelOfferRpcClient(TestWalletRpcClient):
async def get_offer(self, trade_id: bytes32, file_contents: bool = False) -> TradeRecord:
self.add_to_log("get_offer", (trade_id, file_contents))
async def get_offer(self, request: GetOffer) -> GetOfferResponse:
self.add_to_log("get_offer", (request.trade_id, request.file_contents))
offer = Offer.from_bech32(test_offer_file_bech32)
return TradeRecord(
confirmed_at_index=uint32(0),
accepted_at_time=uint64(0),
created_at_time=uint64(12345678),
is_my_offer=True,
sent=uint32(0),
sent_to=[],
offer=bytes(offer),
taken_offer=None,
coins_of_interest=offer.get_involved_coins(),
trade_id=offer.name(),
status=uint32(TradeStatus.PENDING_ACCEPT.value),
valid_times=ConditionValidTimes(),
return GetOfferResponse(
test_offer_file_bech32,
TradeRecord(
confirmed_at_index=uint32(0),
accepted_at_time=uint64(0),
created_at_time=uint64(12345678),
is_my_offer=True,
sent=uint32(0),
sent_to=[],
offer=bytes(offer),
taken_offer=None,
coins_of_interest=offer.get_involved_coins(),
trade_id=offer.name(),
status=uint32(TradeStatus.PENDING_ACCEPT.value),
valid_times=ConditionValidTimes(),
),
)

async def cancel_offer(
Expand Down
111 changes: 86 additions & 25 deletions chia/_tests/wallet/rpc/test_wallet_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,11 @@
DIDTransferDID,
DIDUpdateMetadata,
FungibleAsset,
GetAllOffers,
GetCoinRecordsByNames,
GetNextAddress,
GetNotifications,
GetOffer,
GetOfferSummary,
GetPrivateKey,
GetSpendableCoins,
Expand Down Expand Up @@ -1554,7 +1556,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
CreateOfferForIDs(offer={str(1): "-5", cat_asset_id.hex(): "1"}, validate_only=True),
tx_config=wallet_environments.tx_config,
)
all_offers = await env_1.rpc_client.get_all_offers()
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
assert len(all_offers) == 0

driver_dict = {
Expand Down Expand Up @@ -1598,7 +1600,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
assert offer_validity_response.id == offer.name()
assert offer_validity_response.valid

all_offers = await env_1.rpc_client.get_all_offers(file_contents=True)
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers(file_contents=True))).trade_records
assert len(all_offers) == 1
assert TradeStatus(all_offers[0].status) == TradeStatus.PENDING_ACCEPT
assert all_offers[0].offer == bytes(offer)
Expand All @@ -1622,22 +1624,22 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_

await env_1.rpc_client.cancel_offer(offer.name(), wallet_environments.tx_config, secure=False)

trade_record = await env_1.rpc_client.get_offer(offer.name(), file_contents=True)
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name(), file_contents=True))).trade_record
assert trade_record.offer == bytes(offer)
assert TradeStatus(trade_record.status) == TradeStatus.CANCELLED

failed_cancel_res = await env_1.rpc_client.cancel_offer(
offer.name(), wallet_environments.tx_config, fee=uint64(1), secure=True
)

trade_record = await env_1.rpc_client.get_offer(offer.name())
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record
assert TradeStatus(trade_record.status) == TradeStatus.PENDING_CANCEL

create_res = await env_1.rpc_client.create_offer_for_ids(
CreateOfferForIDs(offer={str(1): "-5", str(cat_wallet_id): "1"}, fee=uint64(1)),
tx_config=wallet_environments.tx_config,
)
all_offers = await env_1.rpc_client.get_all_offers()
all_offers = (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
assert len(all_offers) == 2
offer_count = await env_1.rpc_client.get_offers_count()
assert offer_count.total == 2
Expand Down Expand Up @@ -1717,7 +1719,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_
)

async def is_trade_confirmed(client: WalletRpcClient, offer: Offer) -> bool:
trade_record = await client.get_offer(offer.name())
trade_record = (await client.get_offer(GetOffer(offer.name()))).trade_record
return TradeStatus(trade_record.status) == TradeStatus.CONFIRMED

await time_out_assert(15, is_trade_confirmed, True, env_1.rpc_client, offer)
Expand All @@ -1726,38 +1728,62 @@ async def is_trade_confirmed(client: WalletRpcClient, offer: Offer) -> bool:
def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
return [t.trade_id for t in trades]

trade_record = await env_1.rpc_client.get_offer(offer.name())
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True) # confirmed at index descending
trade_record = (await env_1.rpc_client.get_offer(GetOffer(offer.name()))).trade_record
all_offers = ( # confirmed at index descending
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True))
).trade_records
assert len(all_offers) == 2
assert only_ids(all_offers) == only_ids([trade_record, new_trade_record])
all_offers = await env_1.rpc_client.get_all_offers(
include_completed=True, reverse=True
) # confirmed at index ascending
all_offers = ( # confirmed at index ascending
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, reverse=True))
).trade_records
assert only_ids(all_offers) == only_ids([new_trade_record, trade_record])
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, sort_key="RELEVANCE") # most relevant
all_offers = ( # most relevant
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE"))
).trade_records
assert only_ids(all_offers) == only_ids([new_trade_record, trade_record])
all_offers = await env_1.rpc_client.get_all_offers(
include_completed=True, sort_key="RELEVANCE", reverse=True
) # least relevant
all_offers = ( # least relevant
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, sort_key="RELEVANCE", reverse=True))
).trade_records
assert only_ids(all_offers) == only_ids([trade_record, new_trade_record])
# Test pagination
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=1)
all_offers = (
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(1)))
).trade_records
assert len(all_offers) == 1
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=50)
all_offers = (
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(10)))
).trade_records
assert len(all_offers) == 0
all_offers = await env_1.rpc_client.get_all_offers(include_completed=True, start=0, end=50)
all_offers = (
await env_1.rpc_client.get_all_offers(GetAllOffers(include_completed=True, start=uint16(0), end=uint16(50)))
).trade_records
assert len(all_offers) == 2

await env_1.rpc_client.create_offer_for_ids(
CreateOfferForIDs(offer={str(1): "-5", cat_asset_id.hex(): "1"}, driver_dict=driver_dict),
tx_config=wallet_environments.tx_config,
)
assert (
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2
len(
[
o
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
if o.status == TradeStatus.PENDING_ACCEPT.value
]
)
== 2
)
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, batch_size=1)
assert (
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
len(
[
o
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
if o.status == TradeStatus.PENDING_ACCEPT.value
]
)
== 0
)
await wallet_environments.process_pending_states(
[
Expand Down Expand Up @@ -1794,11 +1820,25 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
tx_config=wallet_environments.tx_config,
)
assert (
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 2
len(
[
o
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
if o.status == TradeStatus.PENDING_ACCEPT.value
]
)
== 2
)
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, cancel_all=True)
assert (
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
len(
[
o
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
if o.status == TradeStatus.PENDING_ACCEPT.value
]
)
== 0
)

await wallet_environments.process_pending_states(
Expand Down Expand Up @@ -1842,15 +1882,36 @@ def only_ids(trades: list[TradeRecord]) -> list[bytes32]:
tx_config=wallet_environments.tx_config,
)
assert (
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1
len(
[
o
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
if o.status == TradeStatus.PENDING_ACCEPT.value
]
)
== 1
)
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=bytes32.zeros)
assert (
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 1
len(
[
o
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
if o.status == TradeStatus.PENDING_ACCEPT.value
]
)
== 1
)
await env_1.rpc_client.cancel_offers(wallet_environments.tx_config, asset_id=cat_asset_id)
assert (
len([o for o in await env_1.rpc_client.get_all_offers() if o.status == TradeStatus.PENDING_ACCEPT.value]) == 0
len(
[
o
for o in (await env_1.rpc_client.get_all_offers(GetAllOffers())).trade_records
if o.status == TradeStatus.PENDING_ACCEPT.value
]
)
== 0
)

with pytest.raises(ValueError, match="not currently supported"):
Expand Down
30 changes: 18 additions & 12 deletions chia/cmds/wallet_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@
DIDUpdateMetadata,
ExtendDerivationIndex,
FungibleAsset,
GetAllOffers,
GetNextAddress,
GetNotifications,
GetOffer,
GetTransaction,
GetTransactions,
GetWalletBalance,
Expand Down Expand Up @@ -768,16 +770,20 @@ async def get_offers(

# Traverse offers page by page
while True:
new_records: list[TradeRecord] = await wallet_client.get_all_offers(
start,
end,
sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT",
reverse=reverse,
file_contents=file_contents,
exclude_my_offers=exclude_my_offers,
exclude_taken_offers=exclude_taken_offers,
include_completed=include_completed,
)
new_records: list[TradeRecord] = (
await wallet_client.get_all_offers(
GetAllOffers(
start=uint16(start),
end=uint16(end),
sort_key="RELEVANCE" if sort_by_relevance else "CONFIRMED_AT_HEIGHT",
reverse=reverse,
file_contents=file_contents,
exclude_my_offers=exclude_my_offers,
exclude_taken_offers=exclude_taken_offers,
include_completed=include_completed,
)
)
).trade_records
records.extend(new_records)

# If fewer records were returned than requested, we're done
Expand All @@ -787,7 +793,7 @@ async def get_offers(
start = end
end += batch_size
else:
records = [await wallet_client.get_offer(offer_id, file_contents)]
records = [(await wallet_client.get_offer(GetOffer(offer_id, file_contents))).trade_record]
if filepath is not None:
with open(pathlib.Path(filepath), "w") as file:
file.write(Offer.from_bytes(records[0].offer).to_bech32())
Expand Down Expand Up @@ -919,7 +925,7 @@ async def cancel_offer(
condition_valid_times: ConditionValidTimes,
) -> list[TransactionRecord]:
async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, fingerprint, config):
trade_record = await wallet_client.get_offer(offer_id, file_contents=True)
trade_record = (await wallet_client.get_offer(GetOffer(offer_id, file_contents=True))).trade_record
await print_trade_record(trade_record, wallet_client, summaries=True)

cli_confirm(f"Are you sure you wish to cancel offer with ID: {trade_record.trade_id}? (y/n): ")
Expand Down
5 changes: 4 additions & 1 deletion chia/data_layer/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
DLUpdateMultiple,
DLUpdateMultipleUpdates,
DLUpdateRoot,
GetOffer,
LauncherRootPair,
LogIn,
TakeOffer,
Expand Down Expand Up @@ -1274,7 +1275,9 @@ async def cancel_offer(self, trade_id: bytes32, secure: bool, fee: uint64) -> No
store_ids: list[bytes32] = []

if not secure:
trade_record = await self.wallet_rpc.get_offer(trade_id=trade_id, file_contents=True)
trade_record = (
await self.wallet_rpc.get_offer(GetOffer(trade_id=trade_id, file_contents=True))
).trade_record
trading_offer = TradingOffer.from_bytes(trade_record.offer)
summary = await DataLayerWallet.get_offer_summary(offer=trading_offer)
store_ids = [offered.launcher_id for offered in summary.offered]
Expand Down
Loading
Loading