diff --git a/packages/testing/src/consensus_testing/test_fixtures/state_transition.py b/packages/testing/src/consensus_testing/test_fixtures/state_transition.py index e24de23d..892b0b27 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/state_transition.py +++ b/packages/testing/src/consensus_testing/test_fixtures/state_transition.py @@ -4,6 +4,7 @@ from pydantic import ConfigDict, PrivateAttr, field_serializer +from lean_spec.subspecs.containers.attestation import Attestation from lean_spec.subspecs.containers.block.block import Block, BlockBody from lean_spec.subspecs.containers.block.types import AggregatedAttestations from lean_spec.subspecs.containers.state.state import State @@ -232,14 +233,17 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block, return block, None # Build the block using the state for standard case + # + # Convert aggregated attestations to plain attestations to build block + plain_attestations = [ + Attestation(validator_id=vid, data=agg.data) + for agg in aggregated_attestations + for vid in agg.aggregation_bits.to_validator_indices() + ] block, post_state, _, _ = state.build_block( slot=spec.slot, proposer_index=proposer_index, parent_root=parent_root, - attestations=[ - attestation - for aggregated_attestation in aggregated_attestations - for attestation in aggregated_attestation.to_plain() - ], + attestations=plain_attestations, ) return block, post_state diff --git a/src/lean_spec/subspecs/containers/attestation/attestation.py b/src/lean_spec/subspecs/containers/attestation/attestation.py index c6614251..dedc6f7d 100644 --- a/src/lean_spec/subspecs/containers/attestation/attestation.py +++ b/src/lean_spec/subspecs/containers/attestation/attestation.py @@ -81,20 +81,6 @@ class AggregatedAttestation(Container): committee assignments. """ - def to_plain(self) -> list[Attestation]: - """ - Expand this aggregated attestation into plain per-validator attestations. - - Returns: - One `Attestation` per participating validator index, all sharing the same - `AttestationData`. - """ - validator_indices = self.aggregation_bits.to_validator_indices() - return [ - Attestation(validator_id=validator_id, data=self.data) - for validator_id in validator_indices - ] - @classmethod def aggregate_by_data( cls, diff --git a/src/lean_spec/subspecs/containers/block/types.py b/src/lean_spec/subspecs/containers/block/types.py index e602ef20..6135e770 100644 --- a/src/lean_spec/subspecs/containers/block/types.py +++ b/src/lean_spec/subspecs/containers/block/types.py @@ -3,7 +3,7 @@ from lean_spec.types import SSZList from ...chain.config import VALIDATOR_REGISTRY_LIMIT -from ..attestation import AggregatedAttestation, NaiveAggregatedSignature +from ..attestation import AggregatedAttestation, AttestationData, NaiveAggregatedSignature class AggregatedAttestations(SSZList): @@ -12,6 +12,21 @@ class AggregatedAttestations(SSZList): ELEMENT_TYPE = AggregatedAttestation LIMIT = int(VALIDATOR_REGISTRY_LIMIT) + def __getitem__(self, index: int) -> AggregatedAttestation: + """Access an aggregated attestation by index with proper typing.""" + item = self.data[index] + assert isinstance(item, AggregatedAttestation) + return item + + def has_duplicate_data(self) -> bool: + """Check if any two attestations share the same AttestationData.""" + seen: set[AttestationData] = set() + for attestation in self: + if attestation.data in seen: + return True + seen.add(attestation.data) + return False + class AttestationSignatures(SSZList): """List of per-attestation naive signature lists aligned with block body attestations.""" diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py index b0cfc728..e1654a59 100644 --- a/src/lean_spec/subspecs/containers/state/state.py +++ b/src/lean_spec/subspecs/containers/state/state.py @@ -12,11 +12,7 @@ is_proposer, ) -from ..attestation import ( - AggregatedAttestation, - Attestation, - SignedAttestation, -) +from ..attestation import AggregatedAttestation, Attestation, SignedAttestation if TYPE_CHECKING: from lean_spec.subspecs.xmss.containers import Signature @@ -365,22 +361,23 @@ def process_block(self, block: Block) -> "State": # First process the block header. state = self.process_block_header(block) - # Process justification attestations by converting aggregated payloads - attestations: list[Attestation] = [] - attestations_data = set() - for aggregated_att in block.body.attestations: - # No partial aggregation is allowed. - if aggregated_att.data in attestations_data: - raise AssertionError("Block contains duplicate AttestationData") - - attestations_data.add(aggregated_att.data) - attestations.extend(aggregated_att.to_plain()) + # Reject blocks with duplicate attestation data + # + # Each aggregated attestation in a block must refer to a unique AttestationData. + # Duplicates would allow the same vote to be counted multiple times, breaking + # the integrity of the justification tally. + # + # This is a protocol-level invariant: honest proposers never include duplicates, + # and validators must reject blocks that violate this rule. + assert not block.body.attestations.has_duplicate_data(), ( + "Block contains duplicate AttestationData" + ) - return state.process_attestations(attestations) + return state.process_attestations(block.body.attestations) def process_attestations( self, - attestations: list[Attestation], + attestations: Iterable[AggregatedAttestation], ) -> "State": """ Apply attestations and update justification/finalization @@ -393,8 +390,8 @@ def process_attestations( Parameters ---------- - attestations : Attestations - The list of attestations to process. + attestations : Iterable[AggregatedAttestation] + The aggregated attestations to process. Returns: ------- @@ -507,12 +504,13 @@ def process_attestations( if target.root not in justifications: justifications[target.root] = [Boolean(False)] * len(self.validators) - # Mark that this validator has voted for the target. + # Mark that each validator in this aggregation has voted for the target. # # A vote is represented as a boolean flag. # If it was previously absent, flip it to True. - if not justifications[target.root][attestation.validator_id]: - justifications[target.root][attestation.validator_id] = Boolean(True) + for validator_id in attestation.aggregation_bits.to_validator_indices(): + if not justifications[target.root][validator_id]: + justifications[target.root][validator_id] = Boolean(True) # Check whether the vote count crosses the supermajority threshold # diff --git a/src/lean_spec/subspecs/forkchoice/store.py b/src/lean_spec/subspecs/forkchoice/store.py index 51e754fb..b7db68c5 100644 --- a/src/lean_spec/subspecs/forkchoice/store.py +++ b/src/lean_spec/subspecs/forkchoice/store.py @@ -467,19 +467,17 @@ def on_block(self, signed_block_with_attestation: SignedBlockWithAttestation) -> for aggregated_attestation, aggregated_signature in zip( aggregated_attestations, attestation_signatures, strict=True ): - plain_attestations = aggregated_attestation.to_plain() + validator_ids = aggregated_attestation.aggregation_bits.to_validator_indices() - assert len(plain_attestations) == len(aggregated_signature), ( + assert len(validator_ids) == len(aggregated_signature), ( "Aggregated attestation signature count mismatch" ) - for attestation, signature in zip( - plain_attestations, aggregated_signature, strict=True - ): + for validator_id, signature in zip(validator_ids, aggregated_signature, strict=True): store = store.on_attestation( signed_attestation=SignedAttestation( - validator_id=attestation.validator_id, - message=attestation.data, + validator_id=validator_id, + message=aggregated_attestation.data, signature=signature, ), is_from_block=True, diff --git a/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py index 1d5e2e13..c42e5628 100644 --- a/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py +++ b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py @@ -13,8 +13,8 @@ from lean_spec.types import Bytes32, Uint64 -class TestAttestationAggregation: - """Test proper attestation aggregation by common data.""" +class TestAggregationBits: + """Test aggregation bits functionality.""" def test_reject_empty_aggregation_bits(self) -> None: """Validate aggregated attestation must include at least one validator.""" @@ -22,9 +22,68 @@ def test_reject_empty_aggregation_bits(self) -> None: with pytest.raises(AssertionError, match="at least one validator"): bits.to_validator_indices() + def test_to_validator_indices_single_bit(self) -> None: + """Test conversion with a single bit set.""" + bits = AggregationBits(data=[False, True, False]) + indices = bits.to_validator_indices() + assert indices == [Uint64(1)] + + def test_to_validator_indices_multiple_bits(self) -> None: + """Test conversion with multiple bits set.""" + bits = AggregationBits(data=[True, False, True, True, False]) + indices = bits.to_validator_indices() + assert indices == [Uint64(0), Uint64(2), Uint64(3)] + + def test_from_validator_indices_roundtrip(self) -> None: + """Test that from_validator_indices and to_validator_indices are inverses.""" + original_indices = [Uint64(1), Uint64(5), Uint64(7)] + bits = AggregationBits.from_validator_indices(original_indices) + recovered_indices = bits.to_validator_indices() + assert recovered_indices == original_indices + + +class TestAggregatedAttestation: + """Test aggregated attestation structure.""" + + def test_aggregated_attestation_structure(self) -> None: + """Test that aggregated attestation properly stores bits and data.""" + att_data = AttestationData( + slot=Slot(5), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)), + source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)), + ) + + bits = AggregationBits.from_validator_indices([Uint64(2), Uint64(7)]) + agg = AggregatedAttestation(aggregation_bits=bits, data=att_data) + + # Verify we can extract validator indices + indices = agg.aggregation_bits.to_validator_indices() + assert set(indices) == {Uint64(2), Uint64(7)} + assert agg.data == att_data + + def test_aggregated_attestation_with_many_validators(self) -> None: + """Test aggregated attestation with many validators.""" + att_data = AttestationData( + slot=Slot(10), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(9)), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(8)), + source=Checkpoint(root=Bytes32.zero(), slot=Slot(7)), + ) + + validator_ids = [Uint64(i) for i in [0, 5, 10, 15, 20, 25]] + bits = AggregationBits.from_validator_indices(validator_ids) + agg = AggregatedAttestation(aggregation_bits=bits, data=att_data) + + recovered = agg.aggregation_bits.to_validator_indices() + assert recovered == validator_ids + + +class TestAggregateByData: + """Test aggregation of plain attestations by common data.""" + def test_aggregate_attestations_by_common_data(self) -> None: """Test that attestations with same data are properly aggregated.""" - # Create three attestations with two having common data att_data1 = AttestationData( slot=Slot(5), head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)), @@ -63,30 +122,6 @@ def test_aggregate_attestations_by_common_data(self) -> None: # Should contain only validator 5 assert set(validator_ids2) == {Uint64(5)} - def test_aggregate_attestations_sets_all_bits(self) -> None: - """Test that aggregation sets all validator bits correctly.""" - att_data = AttestationData( - slot=Slot(5), - head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)), - target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)), - source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)), - ) - - attestations = [ - Attestation(validator_id=Uint64(2), data=att_data), - Attestation(validator_id=Uint64(7), data=att_data), - Attestation(validator_id=Uint64(10), data=att_data), - ] - - aggregated = AggregatedAttestation.aggregate_by_data(attestations) - - assert len(aggregated) == 1 - validator_ids = aggregated[0].aggregation_bits.to_validator_indices() - - # Should have all three validators - assert len(validator_ids) == 3 - assert set(validator_ids) == {Uint64(2), Uint64(7), Uint64(10)} - def test_aggregate_empty_attestations(self) -> None: """Test aggregation with no attestations.""" aggregated = AggregatedAttestation.aggregate_by_data([]) @@ -108,35 +143,3 @@ def test_aggregate_single_attestation(self) -> None: assert len(aggregated) == 1 validator_ids = aggregated[0].aggregation_bits.to_validator_indices() assert validator_ids == [Uint64(5)] - - -class TestDuplicateAttestationDataValidation: - """Test validation that blocks don't contain duplicate AttestationData.""" - - def test_duplicate_attestation_data_detection(self) -> None: - """Ensure conversion to plain attestations preserves duplicates.""" - att_data = AttestationData( - slot=Slot(1), - head=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), - target=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), - source=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), - ) - - from lean_spec.subspecs.containers.attestation import AggregatedAttestation - from lean_spec.subspecs.containers.attestation.types import AggregationBits - - agg1 = AggregatedAttestation( - aggregation_bits=AggregationBits(data=[False, True]), - data=att_data, - ) - agg2 = AggregatedAttestation( - aggregation_bits=AggregationBits(data=[False, True, True]), - data=att_data, - ) - - plain = [plain_att for aggregated in (agg1, agg2) for plain_att in aggregated.to_plain()] - - # Expect 2 plain attestations (because validator 1 is common in agg1 and agg2) - # validator 1 and validator 2 are the only unique validators in the attestations - assert len(set(plain)) == 2 - assert all(att.data == att_data for att in plain)