Skip to content

Commit 5e0a0df

Browse files
authored
Improve dedup (#19565)
* improve dedup * address review comments
1 parent 8ec82c5 commit 5e0a0df

File tree

2 files changed

+229
-20
lines changed

2 files changed

+229
-20
lines changed

chia/_tests/core/mempool/test_mempool_manager.py

Lines changed: 119 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
can_replace,
3535
check_removals,
3636
compute_assert_height,
37+
is_atom_canonical,
38+
is_clvm_canonical,
3739
optional_max,
3840
optional_min,
3941
)
@@ -88,6 +90,74 @@
8890
TEST_HEIGHT = uint32(5)
8991

9092

93+
@pytest.mark.parametrize("clvm_hex", ["80", "ff8080", "ff7f03", "ffff8080ff8080"])
94+
def test_clvm_canonical(clvm_hex: str) -> None:
95+
clvm_buf = bytes.fromhex(clvm_hex)
96+
assert is_clvm_canonical(clvm_buf)
97+
98+
99+
@pytest.mark.parametrize(
100+
"clvm_hex",
101+
[
102+
"fffe80",
103+
"c000",
104+
"c03f",
105+
"e00000",
106+
"e01fff",
107+
"f0000000",
108+
"f00fffff",
109+
"f800000000",
110+
"f807ffffff",
111+
"fc0000000000",
112+
"fc03ffffffff",
113+
"fe",
114+
"ff808080",
115+
],
116+
)
117+
def test_clvm_not_canonical(clvm_hex: str) -> None:
118+
clvm_buf = bytes.fromhex(clvm_hex)
119+
assert not is_clvm_canonical(clvm_buf)
120+
121+
122+
@pytest.mark.parametrize(
123+
"clvm_hex, expect",
124+
[
125+
("c000", 2 + 0),
126+
("c03f", 2 + 0x3F),
127+
("e00000", 3 + 0),
128+
("e01fff", 3 + 0x1FFF),
129+
("f0000000", 4 + 0),
130+
("f00fffff", 4 + 0xFFFFF),
131+
("f800000000", 5 + 0),
132+
("f807ffffff", 5 + 0x7FFFFFF),
133+
("fc0000000000", 6 + 0),
134+
("fc03ffffffff", 6 + 0x3FFFFFFFF),
135+
],
136+
)
137+
def test_atom_not_canonical(clvm_hex: str, expect: int) -> None:
138+
clvm_buf = bytes.fromhex(clvm_hex)
139+
atom_len, is_canonical = is_atom_canonical(clvm_buf, 0)
140+
assert atom_len == expect
141+
assert not is_canonical
142+
143+
144+
@pytest.mark.parametrize(
145+
"clvm_hex, expect",
146+
[
147+
("c040", 2 + 0x40),
148+
("e02000", 3 + 0x2000),
149+
("f0100000", 4 + 0x100000),
150+
("f808000000", 5 + 0x8000000),
151+
("fc0400000000", 6 + 0x400000000),
152+
],
153+
)
154+
def test_atom_canonical(clvm_hex: str, expect: int) -> None:
155+
clvm_buf = bytes.fromhex(clvm_hex)
156+
atom_len, is_canonical = is_atom_canonical(clvm_buf, 0)
157+
assert atom_len == expect
158+
assert is_canonical
159+
160+
91161
@dataclasses.dataclass(frozen=True)
92162
class TestBlockRecord:
93163
"""
@@ -760,8 +830,12 @@ def test_optional_max() -> None:
760830
assert optional_max(uint32(123), uint32(234)) == uint32(234)
761831

762832

763-
def mk_coin_spend(coin: Coin) -> CoinSpend:
764-
return make_spend(coin, SerializedProgram.to(None), SerializedProgram.to(None))
833+
def mk_coin_spend(coin: Coin, solution: Optional[str] = None) -> CoinSpend:
834+
return make_spend(
835+
coin,
836+
SerializedProgram.to(None),
837+
SerializedProgram.to(bytes.fromhex(solution) if solution is not None else None),
838+
)
765839

766840

767841
def mk_bcs(coin_spend: CoinSpend, flags: int = 0) -> BundleCoinSpend:
@@ -781,6 +855,7 @@ def mk_item(
781855
assert_height: Optional[int] = None,
782856
assert_before_height: Optional[int] = None,
783857
assert_before_seconds: Optional[int] = None,
858+
solution: Optional[str] = None,
784859
flags: list[int] = [],
785860
) -> MempoolItem:
786861
# we don't actually care about the puzzle and solutions for the purpose of
@@ -793,7 +868,8 @@ def mk_item(
793868
for c, f in zip(coins, flags):
794869
coin_id = c.name()
795870
spend_ids.append((coin_id, f))
796-
coin_spend = mk_coin_spend(c)
871+
coin_spend = mk_coin_spend(c, solution=solution)
872+
solution = None
797873
coin_spends.append(coin_spend)
798874
bundle_coin_spends[coin_id] = mk_bcs(coin_spend, f)
799875
spend_bundle = SpendBundle(coin_spends, G2Element())
@@ -1642,6 +1718,7 @@ async def get_coin_records(coin_ids: Collection[bytes32]) -> list[CoinRecord]:
16421718

16431719
mempool_manager = await instantiate_mempool_manager(get_coin_records)
16441720
# Create a bunch of mempool items that spend the coin in different ways
1721+
# only the first one will be accepted
16451722
for i in range(3):
16461723
_, _, result = await generate_and_add_spendbundle(
16471724
mempool_manager,
@@ -1651,10 +1728,13 @@ async def get_coin_records(coin_ids: Collection[bytes32]) -> list[CoinRecord]:
16511728
],
16521729
coin,
16531730
)
1654-
assert result[1] == MempoolInclusionStatus.SUCCESS
1655-
assert len(list(mempool_manager.mempool.get_items_by_coin_id(coin_id))) == 3
1656-
assert mempool_manager.mempool.size() == 3
1657-
assert len(list(mempool_manager.mempool.items_by_feerate())) == 3
1731+
if i == 0:
1732+
assert result[1] == MempoolInclusionStatus.SUCCESS
1733+
else:
1734+
assert result[1] == MempoolInclusionStatus.PENDING
1735+
assert len(list(mempool_manager.mempool.get_items_by_coin_id(coin_id))) == 1
1736+
assert mempool_manager.mempool.size() == 1
1737+
assert len(list(mempool_manager.mempool.items_by_feerate())) == 1
16581738
# Setup a new peak where the incoming block has spent the coin
16591739
# Mark this coin as spent
16601740
test_coin_records = {coin_id: CoinRecord(coin, uint32(0), TEST_HEIGHT, False, uint64(0))}
@@ -1833,7 +1913,7 @@ async def make_setup_and_coins(
18331913
sb_ef_name = sb_ef.name()
18341914
await send_to_mempool(full_node_api, sb_ef)
18351915
# Send also a transaction EG that spends E differently from DE and EF,
1836-
# so that it doesn't get deduplicated on E with them
1916+
# to ensure it's rejected by the mempool
18371917
conditions = [
18381918
[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, e_coin.amount],
18391919
[ConditionOpcode.ASSERT_MY_COIN_ID, e_coin.name()],
@@ -1851,14 +1931,13 @@ async def make_setup_and_coins(
18511931
[tx_g] = action_scope.side_effects.transactions
18521932
assert tx_g.spend_bundle is not None
18531933
sb_e2g = SpendBundle.aggregate([sb_e2, tx_g.spend_bundle])
1854-
sb_e2g_name = sb_e2g.name()
1855-
await send_to_mempool(full_node_api, sb_e2g)
1934+
await send_to_mempool(full_node_api, sb_e2g, expecting_conflict=True)
18561935

18571936
# Make sure our coin IDs to spend bundles mappings are correct
18581937
assert get_sb_names_by_coin_id(full_node_api, coins[4].coin.name()) == {sb_de_name}
1859-
assert get_sb_names_by_coin_id(full_node_api, e_coin_id) == {sb_de_name, sb_ef_name, sb_e2g_name}
1938+
assert get_sb_names_by_coin_id(full_node_api, e_coin_id) == {sb_de_name, sb_ef_name}
18601939
assert get_sb_names_by_coin_id(full_node_api, coins[5].coin.name()) == {sb_ef_name}
1861-
assert get_sb_names_by_coin_id(full_node_api, g_coin_id) == {sb_e2g_name}
1940+
assert get_sb_names_by_coin_id(full_node_api, g_coin_id) == set()
18621941

18631942
await farm_a_block(full_node_api, wallet_node, ph)
18641943

@@ -2520,7 +2599,7 @@ async def test_advancing_ff(use_optimization: bool) -> None:
25202599
assert spend.latest_singleton_coin == spend_c.coin.name()
25212600

25222601

2523-
@pytest.mark.parametrize("flags", [ELIGIBLE_FOR_DEDUP, ELIGIBLE_FOR_FF])
2602+
@pytest.mark.parametrize("flags", [ELIGIBLE_FOR_DEDUP, ELIGIBLE_FOR_FF, ELIGIBLE_FOR_FF | ELIGIBLE_FOR_DEDUP])
25242603
@pytest.mark.anyio
25252604
async def test_check_removals_with_block_creation(flags: int) -> None:
25262605
LAUNCHER_ID = bytes32([1] * 32)
@@ -2560,6 +2639,18 @@ async def test_check_removals_with_block_creation(flags: int) -> None:
25602639
assert set(removals) == {singleton_spend.coin, TEST_COIN}
25612640

25622641

2642+
@pytest.mark.anyio
2643+
async def test_dedup_not_canonical() -> None:
2644+
# this is 1, but with a non-canonical encoding
2645+
coin_spend = mk_coin_spend(TEST_COIN, solution="c00101")
2646+
coins = TestCoins(coins=[], lineage={})
2647+
mempool_manager = await setup_mempool(coins)
2648+
sb = SpendBundle([coin_spend], G2Element())
2649+
sb_conds = make_test_conds(spend_ids=[(TEST_COIN, ELIGIBLE_FOR_DEDUP)])
2650+
bundle_add_info = await mempool_manager.add_spend_bundle(sb, sb_conds, sb.name(), uint32(1))
2651+
assert bundle_add_info.status == MempoolInclusionStatus.FAILED
2652+
2653+
25632654
def make_coin_record(coin: Coin, spent_block_index: int = 0) -> CoinRecord:
25642655
return CoinRecord(coin, uint32(0), uint32(spent_block_index), False, TEST_TIMESTAMP)
25652656

@@ -2626,6 +2717,21 @@ class CheckRemovalsCase:
26262717
conflicting_mempool_items={TEST_COIN_ID: [mk_item([TEST_COIN], flags=[ELIGIBLE_FOR_DEDUP])]},
26272718
expected_result=(None, []),
26282719
),
2720+
CheckRemovalsCase(
2721+
id="Dedup coin, Dedup mempool conflict with different solution",
2722+
removals={TEST_COIN_ID: TEST_COIN_RECORD},
2723+
bundle_coin_spends={TEST_COIN_ID: mk_bcs(mk_coin_spend(TEST_COIN, solution="ff8080"), ELIGIBLE_FOR_DEDUP)},
2724+
conflicting_mempool_items={TEST_COIN_ID: [mk_item([TEST_COIN], flags=[ELIGIBLE_FOR_DEDUP])]},
2725+
expected_result=(
2726+
Err.MEMPOOL_CONFLICT,
2727+
[
2728+
mk_item(
2729+
[TEST_COIN],
2730+
flags=[ELIGIBLE_FOR_DEDUP],
2731+
)
2732+
],
2733+
),
2734+
),
26292735
CheckRemovalsCase(
26302736
id="Regular coin, mempool conflict",
26312737
removals={TEST_COIN_ID: TEST_COIN_RECORD},

chia/full_node/mempool_manager.py

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,92 @@ class NewPeakItem:
150150
QUOTE_EXECUTION_COST = 20
151151

152152

153+
def is_atom_canonical(clvm_buffer: bytes, offset: int) -> tuple[int, bool]:
154+
b = clvm_buffer[offset]
155+
if (b & 0b11000000) == 0b10000000:
156+
# 6 bits length prefix
157+
mask = 0b00111111
158+
prefix_len = 0
159+
min_value = 1
160+
elif (b & 0b11100000) == 0b11000000:
161+
# 5 + 8 bits length prefix
162+
mask = 0b00011111
163+
prefix_len = 1
164+
min_value = 1 << 6
165+
elif (b & 0b11110000) == 0b11100000:
166+
# 4 + 8 + 8 bits length prefix
167+
mask = 0b00001111
168+
prefix_len = 2
169+
min_value = 1 << (5 + 8)
170+
elif (b & 0b11111000) == 0b11110000:
171+
# 3 + 8 + 8 + 8 bits length prefix
172+
mask = 0b00000111
173+
prefix_len = 3
174+
min_value = 1 << (4 + 8 + 8)
175+
elif (b & 0b11111100) == 0b11111000:
176+
# 2 + 8 + 8 + 8 + 8 bits length prefix
177+
mask = 0b00000011
178+
prefix_len = 4
179+
min_value = 1 << (3 + 8 + 8 + 8)
180+
elif (b & 0b11111110) == 0b11111100:
181+
# 1 + 8 + 8 + 8 + 8 + 8 bits length prefix
182+
mask = 0b00000001
183+
prefix_len = 5
184+
min_value = 1 << (2 + 8 + 8 + 8 + 8)
185+
186+
atom_len = b & mask
187+
for i in range(prefix_len):
188+
atom_len <<= 8
189+
offset += 1
190+
atom_len |= clvm_buffer[offset]
191+
192+
return 1 + prefix_len + atom_len, atom_len >= min_value
193+
194+
195+
def is_clvm_canonical(clvm_buffer: bytes) -> bool:
196+
"""
197+
checks whether the CLVM serialization is all canonical representation.
198+
atoms can be serialized in more than one way by using more bytes than
199+
necessary to encode the length prefix. This functions ensures that all atoms are
200+
encoded with the shortest representation. back-references are not allowed
201+
and will make this function return false
202+
"""
203+
assert clvm_buffer != b""
204+
205+
offset = 0
206+
tokens_left = 1
207+
while True:
208+
b = clvm_buffer[offset]
209+
210+
# pair
211+
if b == 0xFF:
212+
tokens_left += 1
213+
offset += 1
214+
continue
215+
216+
# back references cannot be considered canonical, since they may be
217+
# encoded in many different ways
218+
if b == 0xFE:
219+
return False
220+
221+
# small atom or NIL
222+
if b <= 0x80:
223+
tokens_left -= 1
224+
offset += 1
225+
else:
226+
atom_len, canonical = is_atom_canonical(clvm_buffer, offset)
227+
if not canonical:
228+
return False
229+
tokens_left -= 1
230+
offset += atom_len
231+
232+
if tokens_left == 0:
233+
break
234+
235+
# if there's garbage at the end, it's not canonical
236+
return offset == len(clvm_buffer)
237+
238+
153239
def check_removals(
154240
removals: dict[bytes32, CoinRecord],
155241
bundle_coin_spends: dict[bytes32, BundleCoinSpend],
@@ -167,20 +253,35 @@ def check_removals(
167253
# 1. Checks if it's been spent already
168254
if removals[coin_id].spent and not coin_bcs.eligible_for_fast_forward:
169255
return Err.DOUBLE_SPEND, []
256+
170257
# 2. Checks if there's a mempool conflict
171-
# Only consider conflicts if the coin is not eligible for deduplication
172258
conflicting_items = get_items_by_coin_ids([coin_id])
173-
if not coin_bcs.eligible_for_fast_forward and not coin_bcs.eligible_for_dedup:
174-
conflicts.update(conflicting_items)
175-
continue
176259
for item in conflicting_items:
177260
if item in conflicts:
178261
continue
179262
conflict_bcs = item.bundle_coin_spends[coin_id]
180-
if (coin_bcs.eligible_for_fast_forward and not conflict_bcs.eligible_for_fast_forward) or (
181-
coin_bcs.eligible_for_dedup and not conflict_bcs.eligible_for_dedup
263+
# if the spend we're adding to the mempool is not DEDUP nor FF, it's
264+
# just a regular conflict
265+
if not coin_bcs.eligible_for_fast_forward and not coin_bcs.eligible_for_dedup:
266+
conflicts.add(item)
267+
268+
# if the spend we're adding is FF, but there's a conflicting spend
269+
# that isn't FF, they can't be chained, so that's a conflict
270+
elif coin_bcs.eligible_for_fast_forward and not conflict_bcs.eligible_for_fast_forward:
271+
conflicts.add(item)
272+
273+
# if the spend we're adding is DEDUP, but there's a conflicting spend
274+
# that isn't DEDUP, we cannot merge them, so that's a conflict
275+
elif coin_bcs.eligible_for_dedup and not conflict_bcs.eligible_for_dedup:
276+
conflicts.add(item)
277+
278+
# if the spend we're adding is DEDUP but the existing spend has a
279+
# different solution, we cannot merge them, so that's a conflict
280+
elif coin_bcs.eligible_for_dedup and bytes(coin_bcs.coin_spend.solution) != bytes(
281+
conflict_bcs.coin_spend.solution
182282
):
183283
conflicts.add(item)
284+
184285
if len(conflicts) > 0:
185286
return Err.MEMPOOL_CONFLICT, list(conflicts)
186287
return None, []
@@ -523,6 +624,8 @@ async def validate_spend_bundle(
523624
coin_id,
524625
EligibilityAndAdditions(is_eligible_for_dedup=False, spend_additions=[], ff_puzzle_hash=None),
525626
)
627+
628+
supports_dedup = eligibility_info.is_eligible_for_dedup and is_clvm_canonical(bytes(coin_spend.solution))
526629
mark_as_fast_forward = eligibility_info.ff_puzzle_hash is not None and supports_fast_forward(coin_spend)
527630
latest_singleton_coin = None
528631
if mark_as_fast_forward:
@@ -536,7 +639,7 @@ async def validate_spend_bundle(
536639
latest_singleton_coin = lineage_info.coin_id
537640
bundle_coin_spends[coin_id] = BundleCoinSpend(
538641
coin_spend=coin_spend,
539-
eligible_for_dedup=eligibility_info.is_eligible_for_dedup,
642+
eligible_for_dedup=supports_dedup,
540643
eligible_for_fast_forward=mark_as_fast_forward,
541644
additions=eligibility_info.spend_additions,
542645
latest_singleton_coin=latest_singleton_coin,

0 commit comments

Comments
 (0)