Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import ConfigDict, PrivateAttr, field_serializer

from lean_spec.subspecs.containers.attestation import Attestation
from lean_spec.subspecs.containers.block.block import Block, BlockBody
from lean_spec.subspecs.containers.block.types import AggregatedAttestations
from lean_spec.subspecs.containers.state.state import State
Expand Down Expand Up @@ -232,14 +233,17 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
return block, None

# Build the block using the state for standard case
#
# Convert aggregated attestations to plain attestations to build block
plain_attestations = [
Attestation(validator_id=vid, data=agg.data)
for agg in aggregated_attestations
for vid in agg.aggregation_bits.to_validator_indices()
]
block, post_state, _, _ = state.build_block(
slot=spec.slot,
proposer_index=proposer_index,
parent_root=parent_root,
attestations=[
attestation
for aggregated_attestation in aggregated_attestations
for attestation in aggregated_attestation.to_plain()
],
attestations=plain_attestations,
)
return block, post_state
14 changes: 0 additions & 14 deletions src/lean_spec/subspecs/containers/attestation/attestation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,6 @@ class AggregatedAttestation(Container):
committee assignments.
"""

def to_plain(self) -> list[Attestation]:
"""
Expand this aggregated attestation into plain per-validator attestations.

Returns:
One `Attestation` per participating validator index, all sharing the same
`AttestationData`.
"""
validator_indices = self.aggregation_bits.to_validator_indices()
return [
Attestation(validator_id=validator_id, data=self.data)
for validator_id in validator_indices
]

@classmethod
def aggregate_by_data(
cls,
Expand Down
17 changes: 16 additions & 1 deletion src/lean_spec/subspecs/containers/block/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from lean_spec.types import SSZList

from ...chain.config import VALIDATOR_REGISTRY_LIMIT
from ..attestation import AggregatedAttestation, NaiveAggregatedSignature
from ..attestation import AggregatedAttestation, AttestationData, NaiveAggregatedSignature


class AggregatedAttestations(SSZList):
Expand All @@ -12,6 +12,21 @@ class AggregatedAttestations(SSZList):
ELEMENT_TYPE = AggregatedAttestation
LIMIT = int(VALIDATOR_REGISTRY_LIMIT)

def __getitem__(self, index: int) -> AggregatedAttestation:
"""Access an aggregated attestation by index with proper typing."""
item = self.data[index]
assert isinstance(item, AggregatedAttestation)
return item

def has_duplicate_data(self) -> bool:
"""Check if any two attestations share the same AttestationData."""
seen: set[AttestationData] = set()
for attestation in self:
if attestation.data in seen:
return True
seen.add(attestation.data)
return False


class AttestationSignatures(SSZList):
"""List of per-attestation naive signature lists aligned with block body attestations."""
Expand Down
42 changes: 20 additions & 22 deletions src/lean_spec/subspecs/containers/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
is_proposer,
)

from ..attestation import (
AggregatedAttestation,
Attestation,
SignedAttestation,
)
from ..attestation import AggregatedAttestation, Attestation, SignedAttestation

if TYPE_CHECKING:
from lean_spec.subspecs.xmss.containers import Signature
Expand Down Expand Up @@ -365,22 +361,23 @@ def process_block(self, block: Block) -> "State":
# First process the block header.
state = self.process_block_header(block)

# Process justification attestations by converting aggregated payloads
attestations: list[Attestation] = []
attestations_data = set()
for aggregated_att in block.body.attestations:
# No partial aggregation is allowed.
if aggregated_att.data in attestations_data:
raise AssertionError("Block contains duplicate AttestationData")

attestations_data.add(aggregated_att.data)
attestations.extend(aggregated_att.to_plain())
# Reject blocks with duplicate attestation data
#
# Each aggregated attestation in a block must refer to a unique AttestationData.
# Duplicates would allow the same vote to be counted multiple times, breaking
# the integrity of the justification tally.
#
# This is a protocol-level invariant: honest proposers never include duplicates,
# and validators must reject blocks that violate this rule.
assert not block.body.attestations.has_duplicate_data(), (
"Block contains duplicate AttestationData"
)

return state.process_attestations(attestations)
return state.process_attestations(block.body.attestations)

def process_attestations(
self,
attestations: list[Attestation],
attestations: Iterable[AggregatedAttestation],
) -> "State":
"""
Apply attestations and update justification/finalization
Expand All @@ -393,8 +390,8 @@ def process_attestations(

Parameters
----------
attestations : Attestations
The list of attestations to process.
attestations : Iterable[AggregatedAttestation]
The aggregated attestations to process.

Returns:
-------
Expand Down Expand Up @@ -507,12 +504,13 @@ def process_attestations(
if target.root not in justifications:
justifications[target.root] = [Boolean(False)] * len(self.validators)

# Mark that this validator has voted for the target.
# Mark that each validator in this aggregation has voted for the target.
#
# A vote is represented as a boolean flag.
# If it was previously absent, flip it to True.
if not justifications[target.root][attestation.validator_id]:
justifications[target.root][attestation.validator_id] = Boolean(True)
for validator_id in attestation.aggregation_bits.to_validator_indices():
if not justifications[target.root][validator_id]:
justifications[target.root][validator_id] = Boolean(True)

# Check whether the vote count crosses the supermajority threshold
#
Expand Down
12 changes: 5 additions & 7 deletions src/lean_spec/subspecs/forkchoice/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,19 +467,17 @@ def on_block(self, signed_block_with_attestation: SignedBlockWithAttestation) ->
for aggregated_attestation, aggregated_signature in zip(
aggregated_attestations, attestation_signatures, strict=True
):
plain_attestations = aggregated_attestation.to_plain()
validator_ids = aggregated_attestation.aggregation_bits.to_validator_indices()

assert len(plain_attestations) == len(aggregated_signature), (
assert len(validator_ids) == len(aggregated_signature), (
"Aggregated attestation signature count mismatch"
)

for attestation, signature in zip(
plain_attestations, aggregated_signature, strict=True
):
for validator_id, signature in zip(validator_ids, aggregated_signature, strict=True):
store = store.on_attestation(
signed_attestation=SignedAttestation(
validator_id=attestation.validator_id,
message=attestation.data,
validator_id=validator_id,
message=aggregated_attestation.data,
signature=signature,
),
is_from_block=True,
Expand Down
121 changes: 62 additions & 59 deletions tests/lean_spec/subspecs/containers/test_attestation_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,77 @@
from lean_spec.types import Bytes32, Uint64


class TestAttestationAggregation:
"""Test proper attestation aggregation by common data."""
class TestAggregationBits:
"""Test aggregation bits functionality."""

def test_reject_empty_aggregation_bits(self) -> None:
"""Validate aggregated attestation must include at least one validator."""
bits = AggregationBits(data=[False, False, False])
with pytest.raises(AssertionError, match="at least one validator"):
bits.to_validator_indices()

def test_to_validator_indices_single_bit(self) -> None:
"""Test conversion with a single bit set."""
bits = AggregationBits(data=[False, True, False])
indices = bits.to_validator_indices()
assert indices == [Uint64(1)]

def test_to_validator_indices_multiple_bits(self) -> None:
"""Test conversion with multiple bits set."""
bits = AggregationBits(data=[True, False, True, True, False])
indices = bits.to_validator_indices()
assert indices == [Uint64(0), Uint64(2), Uint64(3)]

def test_from_validator_indices_roundtrip(self) -> None:
"""Test that from_validator_indices and to_validator_indices are inverses."""
original_indices = [Uint64(1), Uint64(5), Uint64(7)]
bits = AggregationBits.from_validator_indices(original_indices)
recovered_indices = bits.to_validator_indices()
assert recovered_indices == original_indices


class TestAggregatedAttestation:
"""Test aggregated attestation structure."""

def test_aggregated_attestation_structure(self) -> None:
"""Test that aggregated attestation properly stores bits and data."""
att_data = AttestationData(
slot=Slot(5),
head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)),
target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)),
source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)),
)

bits = AggregationBits.from_validator_indices([Uint64(2), Uint64(7)])
agg = AggregatedAttestation(aggregation_bits=bits, data=att_data)

# Verify we can extract validator indices
indices = agg.aggregation_bits.to_validator_indices()
assert set(indices) == {Uint64(2), Uint64(7)}
assert agg.data == att_data

def test_aggregated_attestation_with_many_validators(self) -> None:
"""Test aggregated attestation with many validators."""
att_data = AttestationData(
slot=Slot(10),
head=Checkpoint(root=Bytes32.zero(), slot=Slot(9)),
target=Checkpoint(root=Bytes32.zero(), slot=Slot(8)),
source=Checkpoint(root=Bytes32.zero(), slot=Slot(7)),
)

validator_ids = [Uint64(i) for i in [0, 5, 10, 15, 20, 25]]
bits = AggregationBits.from_validator_indices(validator_ids)
agg = AggregatedAttestation(aggregation_bits=bits, data=att_data)

recovered = agg.aggregation_bits.to_validator_indices()
assert recovered == validator_ids


class TestAggregateByData:
"""Test aggregation of plain attestations by common data."""

def test_aggregate_attestations_by_common_data(self) -> None:
"""Test that attestations with same data are properly aggregated."""
# Create three attestations with two having common data
att_data1 = AttestationData(
slot=Slot(5),
head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)),
Expand Down Expand Up @@ -63,30 +122,6 @@ def test_aggregate_attestations_by_common_data(self) -> None:
# Should contain only validator 5
assert set(validator_ids2) == {Uint64(5)}

def test_aggregate_attestations_sets_all_bits(self) -> None:
"""Test that aggregation sets all validator bits correctly."""
att_data = AttestationData(
slot=Slot(5),
head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)),
target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)),
source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)),
)

attestations = [
Attestation(validator_id=Uint64(2), data=att_data),
Attestation(validator_id=Uint64(7), data=att_data),
Attestation(validator_id=Uint64(10), data=att_data),
]

aggregated = AggregatedAttestation.aggregate_by_data(attestations)

assert len(aggregated) == 1
validator_ids = aggregated[0].aggregation_bits.to_validator_indices()

# Should have all three validators
assert len(validator_ids) == 3
assert set(validator_ids) == {Uint64(2), Uint64(7), Uint64(10)}

def test_aggregate_empty_attestations(self) -> None:
"""Test aggregation with no attestations."""
aggregated = AggregatedAttestation.aggregate_by_data([])
Expand All @@ -108,35 +143,3 @@ def test_aggregate_single_attestation(self) -> None:
assert len(aggregated) == 1
validator_ids = aggregated[0].aggregation_bits.to_validator_indices()
assert validator_ids == [Uint64(5)]


class TestDuplicateAttestationDataValidation:
"""Test validation that blocks don't contain duplicate AttestationData."""

def test_duplicate_attestation_data_detection(self) -> None:
"""Ensure conversion to plain attestations preserves duplicates."""
att_data = AttestationData(
slot=Slot(1),
head=Checkpoint(root=Bytes32.zero(), slot=Slot(0)),
target=Checkpoint(root=Bytes32.zero(), slot=Slot(0)),
source=Checkpoint(root=Bytes32.zero(), slot=Slot(0)),
)

from lean_spec.subspecs.containers.attestation import AggregatedAttestation
from lean_spec.subspecs.containers.attestation.types import AggregationBits

agg1 = AggregatedAttestation(
aggregation_bits=AggregationBits(data=[False, True]),
data=att_data,
)
agg2 = AggregatedAttestation(
aggregation_bits=AggregationBits(data=[False, True, True]),
data=att_data,
)

plain = [plain_att for aggregated in (agg1, agg2) for plain_att in aggregated.to_plain()]

# Expect 2 plain attestations (because validator 1 is common in agg1 and agg2)
# validator 1 and validator 2 are the only unique validators in the attestations
assert len(set(plain)) == 2
assert all(att.data == att_data for att in plain)
Loading