Skip to content

Commit 87d2b11

Browse files
authored
attestation: some simplifications (#234)
* attestation: some simplifications * doc touchup * cleanup duplication test * better doc
1 parent 29b45d8 commit 87d2b11

File tree

6 files changed

+112
-108
lines changed

6 files changed

+112
-108
lines changed

packages/testing/src/consensus_testing/test_fixtures/state_transition.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pydantic import ConfigDict, PrivateAttr, field_serializer
66

7+
from lean_spec.subspecs.containers.attestation import Attestation
78
from lean_spec.subspecs.containers.block.block import Block, BlockBody
89
from lean_spec.subspecs.containers.block.types import AggregatedAttestations
910
from lean_spec.subspecs.containers.state.state import State
@@ -232,14 +233,17 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
232233
return block, None
233234

234235
# Build the block using the state for standard case
236+
#
237+
# Convert aggregated attestations to plain attestations to build block
238+
plain_attestations = [
239+
Attestation(validator_id=vid, data=agg.data)
240+
for agg in aggregated_attestations
241+
for vid in agg.aggregation_bits.to_validator_indices()
242+
]
235243
block, post_state, _, _ = state.build_block(
236244
slot=spec.slot,
237245
proposer_index=proposer_index,
238246
parent_root=parent_root,
239-
attestations=[
240-
attestation
241-
for aggregated_attestation in aggregated_attestations
242-
for attestation in aggregated_attestation.to_plain()
243-
],
247+
attestations=plain_attestations,
244248
)
245249
return block, post_state

src/lean_spec/subspecs/containers/attestation/attestation.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,6 @@ class AggregatedAttestation(Container):
8181
committee assignments.
8282
"""
8383

84-
def to_plain(self) -> list[Attestation]:
85-
"""
86-
Expand this aggregated attestation into plain per-validator attestations.
87-
88-
Returns:
89-
One `Attestation` per participating validator index, all sharing the same
90-
`AttestationData`.
91-
"""
92-
validator_indices = self.aggregation_bits.to_validator_indices()
93-
return [
94-
Attestation(validator_id=validator_id, data=self.data)
95-
for validator_id in validator_indices
96-
]
97-
9884
@classmethod
9985
def aggregate_by_data(
10086
cls,

src/lean_spec/subspecs/containers/block/types.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from lean_spec.types import SSZList
44

55
from ...chain.config import VALIDATOR_REGISTRY_LIMIT
6-
from ..attestation import AggregatedAttestation, NaiveAggregatedSignature
6+
from ..attestation import AggregatedAttestation, AttestationData, NaiveAggregatedSignature
77

88

99
class AggregatedAttestations(SSZList):
@@ -12,6 +12,21 @@ class AggregatedAttestations(SSZList):
1212
ELEMENT_TYPE = AggregatedAttestation
1313
LIMIT = int(VALIDATOR_REGISTRY_LIMIT)
1414

15+
def __getitem__(self, index: int) -> AggregatedAttestation:
16+
"""Access an aggregated attestation by index with proper typing."""
17+
item = self.data[index]
18+
assert isinstance(item, AggregatedAttestation)
19+
return item
20+
21+
def has_duplicate_data(self) -> bool:
22+
"""Check if any two attestations share the same AttestationData."""
23+
seen: set[AttestationData] = set()
24+
for attestation in self:
25+
if attestation.data in seen:
26+
return True
27+
seen.add(attestation.data)
28+
return False
29+
1530

1631
class AttestationSignatures(SSZList):
1732
"""List of per-attestation naive signature lists aligned with block body attestations."""

src/lean_spec/subspecs/containers/state/state.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@
1212
is_proposer,
1313
)
1414

15-
from ..attestation import (
16-
AggregatedAttestation,
17-
Attestation,
18-
SignedAttestation,
19-
)
15+
from ..attestation import AggregatedAttestation, Attestation, SignedAttestation
2016

2117
if TYPE_CHECKING:
2218
from lean_spec.subspecs.xmss.containers import Signature
@@ -365,22 +361,23 @@ def process_block(self, block: Block) -> "State":
365361
# First process the block header.
366362
state = self.process_block_header(block)
367363

368-
# Process justification attestations by converting aggregated payloads
369-
attestations: list[Attestation] = []
370-
attestations_data = set()
371-
for aggregated_att in block.body.attestations:
372-
# No partial aggregation is allowed.
373-
if aggregated_att.data in attestations_data:
374-
raise AssertionError("Block contains duplicate AttestationData")
375-
376-
attestations_data.add(aggregated_att.data)
377-
attestations.extend(aggregated_att.to_plain())
364+
# Reject blocks with duplicate attestation data
365+
#
366+
# Each aggregated attestation in a block must refer to a unique AttestationData.
367+
# Duplicates would allow the same vote to be counted multiple times, breaking
368+
# the integrity of the justification tally.
369+
#
370+
# This is a protocol-level invariant: honest proposers never include duplicates,
371+
# and validators must reject blocks that violate this rule.
372+
assert not block.body.attestations.has_duplicate_data(), (
373+
"Block contains duplicate AttestationData"
374+
)
378375

379-
return state.process_attestations(attestations)
376+
return state.process_attestations(block.body.attestations)
380377

381378
def process_attestations(
382379
self,
383-
attestations: list[Attestation],
380+
attestations: Iterable[AggregatedAttestation],
384381
) -> "State":
385382
"""
386383
Apply attestations and update justification/finalization
@@ -393,8 +390,8 @@ def process_attestations(
393390
394391
Parameters
395392
----------
396-
attestations : Attestations
397-
The list of attestations to process.
393+
attestations : Iterable[AggregatedAttestation]
394+
The aggregated attestations to process.
398395
399396
Returns:
400397
-------
@@ -507,12 +504,13 @@ def process_attestations(
507504
if target.root not in justifications:
508505
justifications[target.root] = [Boolean(False)] * len(self.validators)
509506

510-
# Mark that this validator has voted for the target.
507+
# Mark that each validator in this aggregation has voted for the target.
511508
#
512509
# A vote is represented as a boolean flag.
513510
# If it was previously absent, flip it to True.
514-
if not justifications[target.root][attestation.validator_id]:
515-
justifications[target.root][attestation.validator_id] = Boolean(True)
511+
for validator_id in attestation.aggregation_bits.to_validator_indices():
512+
if not justifications[target.root][validator_id]:
513+
justifications[target.root][validator_id] = Boolean(True)
516514

517515
# Check whether the vote count crosses the supermajority threshold
518516
#

src/lean_spec/subspecs/forkchoice/store.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -467,19 +467,17 @@ def on_block(self, signed_block_with_attestation: SignedBlockWithAttestation) ->
467467
for aggregated_attestation, aggregated_signature in zip(
468468
aggregated_attestations, attestation_signatures, strict=True
469469
):
470-
plain_attestations = aggregated_attestation.to_plain()
470+
validator_ids = aggregated_attestation.aggregation_bits.to_validator_indices()
471471

472-
assert len(plain_attestations) == len(aggregated_signature), (
472+
assert len(validator_ids) == len(aggregated_signature), (
473473
"Aggregated attestation signature count mismatch"
474474
)
475475

476-
for attestation, signature in zip(
477-
plain_attestations, aggregated_signature, strict=True
478-
):
476+
for validator_id, signature in zip(validator_ids, aggregated_signature, strict=True):
479477
store = store.on_attestation(
480478
signed_attestation=SignedAttestation(
481-
validator_id=attestation.validator_id,
482-
message=attestation.data,
479+
validator_id=validator_id,
480+
message=aggregated_attestation.data,
483481
signature=signature,
484482
),
485483
is_from_block=True,

tests/lean_spec/subspecs/containers/test_attestation_aggregation.py

Lines changed: 62 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,77 @@
1313
from lean_spec.types import Bytes32, Uint64
1414

1515

16-
class TestAttestationAggregation:
17-
"""Test proper attestation aggregation by common data."""
16+
class TestAggregationBits:
17+
"""Test aggregation bits functionality."""
1818

1919
def test_reject_empty_aggregation_bits(self) -> None:
2020
"""Validate aggregated attestation must include at least one validator."""
2121
bits = AggregationBits(data=[False, False, False])
2222
with pytest.raises(AssertionError, match="at least one validator"):
2323
bits.to_validator_indices()
2424

25+
def test_to_validator_indices_single_bit(self) -> None:
26+
"""Test conversion with a single bit set."""
27+
bits = AggregationBits(data=[False, True, False])
28+
indices = bits.to_validator_indices()
29+
assert indices == [Uint64(1)]
30+
31+
def test_to_validator_indices_multiple_bits(self) -> None:
32+
"""Test conversion with multiple bits set."""
33+
bits = AggregationBits(data=[True, False, True, True, False])
34+
indices = bits.to_validator_indices()
35+
assert indices == [Uint64(0), Uint64(2), Uint64(3)]
36+
37+
def test_from_validator_indices_roundtrip(self) -> None:
38+
"""Test that from_validator_indices and to_validator_indices are inverses."""
39+
original_indices = [Uint64(1), Uint64(5), Uint64(7)]
40+
bits = AggregationBits.from_validator_indices(original_indices)
41+
recovered_indices = bits.to_validator_indices()
42+
assert recovered_indices == original_indices
43+
44+
45+
class TestAggregatedAttestation:
46+
"""Test aggregated attestation structure."""
47+
48+
def test_aggregated_attestation_structure(self) -> None:
49+
"""Test that aggregated attestation properly stores bits and data."""
50+
att_data = AttestationData(
51+
slot=Slot(5),
52+
head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)),
53+
target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)),
54+
source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)),
55+
)
56+
57+
bits = AggregationBits.from_validator_indices([Uint64(2), Uint64(7)])
58+
agg = AggregatedAttestation(aggregation_bits=bits, data=att_data)
59+
60+
# Verify we can extract validator indices
61+
indices = agg.aggregation_bits.to_validator_indices()
62+
assert set(indices) == {Uint64(2), Uint64(7)}
63+
assert agg.data == att_data
64+
65+
def test_aggregated_attestation_with_many_validators(self) -> None:
66+
"""Test aggregated attestation with many validators."""
67+
att_data = AttestationData(
68+
slot=Slot(10),
69+
head=Checkpoint(root=Bytes32.zero(), slot=Slot(9)),
70+
target=Checkpoint(root=Bytes32.zero(), slot=Slot(8)),
71+
source=Checkpoint(root=Bytes32.zero(), slot=Slot(7)),
72+
)
73+
74+
validator_ids = [Uint64(i) for i in [0, 5, 10, 15, 20, 25]]
75+
bits = AggregationBits.from_validator_indices(validator_ids)
76+
agg = AggregatedAttestation(aggregation_bits=bits, data=att_data)
77+
78+
recovered = agg.aggregation_bits.to_validator_indices()
79+
assert recovered == validator_ids
80+
81+
82+
class TestAggregateByData:
83+
"""Test aggregation of plain attestations by common data."""
84+
2585
def test_aggregate_attestations_by_common_data(self) -> None:
2686
"""Test that attestations with same data are properly aggregated."""
27-
# Create three attestations with two having common data
2887
att_data1 = AttestationData(
2988
slot=Slot(5),
3089
head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)),
@@ -63,30 +122,6 @@ def test_aggregate_attestations_by_common_data(self) -> None:
63122
# Should contain only validator 5
64123
assert set(validator_ids2) == {Uint64(5)}
65124

66-
def test_aggregate_attestations_sets_all_bits(self) -> None:
67-
"""Test that aggregation sets all validator bits correctly."""
68-
att_data = AttestationData(
69-
slot=Slot(5),
70-
head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)),
71-
target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)),
72-
source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)),
73-
)
74-
75-
attestations = [
76-
Attestation(validator_id=Uint64(2), data=att_data),
77-
Attestation(validator_id=Uint64(7), data=att_data),
78-
Attestation(validator_id=Uint64(10), data=att_data),
79-
]
80-
81-
aggregated = AggregatedAttestation.aggregate_by_data(attestations)
82-
83-
assert len(aggregated) == 1
84-
validator_ids = aggregated[0].aggregation_bits.to_validator_indices()
85-
86-
# Should have all three validators
87-
assert len(validator_ids) == 3
88-
assert set(validator_ids) == {Uint64(2), Uint64(7), Uint64(10)}
89-
90125
def test_aggregate_empty_attestations(self) -> None:
91126
"""Test aggregation with no attestations."""
92127
aggregated = AggregatedAttestation.aggregate_by_data([])
@@ -108,35 +143,3 @@ def test_aggregate_single_attestation(self) -> None:
108143
assert len(aggregated) == 1
109144
validator_ids = aggregated[0].aggregation_bits.to_validator_indices()
110145
assert validator_ids == [Uint64(5)]
111-
112-
113-
class TestDuplicateAttestationDataValidation:
114-
"""Test validation that blocks don't contain duplicate AttestationData."""
115-
116-
def test_duplicate_attestation_data_detection(self) -> None:
117-
"""Ensure conversion to plain attestations preserves duplicates."""
118-
att_data = AttestationData(
119-
slot=Slot(1),
120-
head=Checkpoint(root=Bytes32.zero(), slot=Slot(0)),
121-
target=Checkpoint(root=Bytes32.zero(), slot=Slot(0)),
122-
source=Checkpoint(root=Bytes32.zero(), slot=Slot(0)),
123-
)
124-
125-
from lean_spec.subspecs.containers.attestation import AggregatedAttestation
126-
from lean_spec.subspecs.containers.attestation.types import AggregationBits
127-
128-
agg1 = AggregatedAttestation(
129-
aggregation_bits=AggregationBits(data=[False, True]),
130-
data=att_data,
131-
)
132-
agg2 = AggregatedAttestation(
133-
aggregation_bits=AggregationBits(data=[False, True, True]),
134-
data=att_data,
135-
)
136-
137-
plain = [plain_att for aggregated in (agg1, agg2) for plain_att in aggregated.to_plain()]
138-
139-
# Expect 2 plain attestations (because validator 1 is common in agg1 and agg2)
140-
# validator 1 and validator 2 are the only unique validators in the attestations
141-
assert len(set(plain)) == 2
142-
assert all(att.data == att_data for att in plain)

0 commit comments

Comments
 (0)