diff --git a/l1-contracts/src/core/libraries/crypto/SampleLib.sol b/l1-contracts/src/core/libraries/crypto/SampleLib.sol index 97bc18267ec8..f0ceb1190087 100644 --- a/l1-contracts/src/core/libraries/crypto/SampleLib.sol +++ b/l1-contracts/src/core/libraries/crypto/SampleLib.sol @@ -3,22 +3,13 @@ pragma solidity >=0.8.27; import {Errors} from "@aztec/core/libraries/Errors.sol"; -import {SlotDerivation} from "@oz/utils/SlotDerivation.sol"; -import {TransientSlot} from "@oz/utils/TransientSlot.sol"; /** * @title SampleLib * @author Anaxandridas II - * @notice A tiny library to draw committee indices using a sample without replacement algorithm. + * @notice A tiny library to draw committee indices using a sample without replacement algorithm based on Feistel permutations. */ library SampleLib { - using SlotDerivation for string; - using SlotDerivation for bytes32; - using TransientSlot for *; - - // Namespace for transient storage keys used within this library - string private constant OVERRIDE_NAMESPACE = "Aztec.SampleLib.Override"; - /** * Compute Committee * @@ -32,6 +23,7 @@ library SampleLib { */ function computeCommittee(uint256 _committeeSize, uint256 _indexCount, uint256 _seed) internal + pure returns (uint256[] memory) { require( @@ -43,73 +35,158 @@ library SampleLib { return new uint256[](0); } - uint256[] memory sampledIndices = new uint256[](_committeeSize); + // Use optimized batch Feistel computation + return computeCommitteeBatch(_committeeSize, _indexCount, _seed); + } - uint256 upperLimit = _indexCount - 1; + /** + * @notice Compute the sample index for a given index, seed and index count. + * + * @param _index - The index to shuffle + * @param _indexCount - The total number of indices + * @param _seed - The seed to use for shuffling + * + * @return shuffledIndex - The shuffled index + */ + function computeSampleIndex(uint256 _index, uint256 _indexCount, uint256 _seed) + internal + pure + returns (uint256) + { + // Cannot modulo by 0 + if (_indexCount == 0) { + return 0; + } - for (uint256 index = 0; index < _committeeSize; index++) { - uint256 sampledIndex = computeSampleIndex(index, upperLimit + 1, _seed); + return uint256(keccak256(abi.encodePacked(_seed, _index))) % _indexCount; + } - // Get index, or its swapped override - sampledIndices[index] = getValue(sampledIndex); - if (upperLimit > 0) { - // Swap with the last index - setOverrideValue(sampledIndex, getValue(upperLimit)); - // Decrement the upper limit - upperLimit--; - } - } + /** + * @notice Compute a single committee member at a specific index in O(1) time + * without allowing duplicates using a Feistel network permutation + * + * @param _index - The index in the committee (0 to _committeeSize-1) + * @param _committeeSize - The size of the committee + * @param _indexCount - The total number of validators to sample from + * @param _seed - The seed for randomization + * + * @return The validator index for the committee position + */ + function computeCommitteeMemberAtIndex( + uint256 _index, + uint256 _committeeSize, + uint256 _indexCount, + uint256 _seed + ) internal pure returns (uint256) { + require(_index < _committeeSize, Errors.SampleLib__IndexOutOfBounds(_index, _committeeSize)); + require( + _committeeSize <= _indexCount, + Errors.SampleLib__SampleLargerThanIndex(_committeeSize, _indexCount) + ); - // Clear transient storage. - // Note that we are cleaing the `sampleIndicies` and do not keep track of a separate list of - // `sampleIndex` that was written to. The reasoning being that we are only overwriting for - // duplicate cases, so `sampleIndicies` isa superset of the `sampleIndex` that have been drawn - // (due to account for duplicates). Thereby clearing the `sampleIndicies` clears all. - // Due to the cost of `tstore` and `tload` it is cheaper just to overwrite it all, than checking - // if there is even anything to override. - for (uint256 i = 0; i < _committeeSize; i++) { - setOverrideValue(sampledIndices[i], 0); - } + // Use a Feistel network to create a permutation of [0, _indexCount) + uint256 permutedIndex = feistelPermute(_index, _indexCount, _seed); - return sampledIndices; - } + // Cycle walking: keep applying the permutation until we get a value < _indexCount + // This handles non-power-of-2 sizes + while (permutedIndex >= _indexCount) { + permutedIndex = feistelPermute(permutedIndex, _indexCount, _seed); + } - function setOverrideValue(uint256 _index, uint256 _value) internal { - OVERRIDE_NAMESPACE.erc7201Slot().deriveMapping(_index).asUint256().tstore(_value); + return permutedIndex; } - function getValue(uint256 _index) internal view returns (uint256) { - uint256 overrideValue = getOverrideValue(_index); - if (overrideValue != 0) { - return overrideValue; + /** + * @notice Optimized batch computation of committee members using Feistel network + * Computes all indices simultaneously with shared pre-computation + * + * @param _committeeSize - The size of the committee + * @param _indexCount - The total number of validators to sample from + * @param _seed - The seed for randomization + * + * @return The array of validator indices for the committee + */ + function computeCommitteeBatch(uint256 _committeeSize, uint256 _indexCount, uint256 _seed) + internal + pure + returns (uint256[] memory) + { + uint256[] memory indices = new uint256[](_committeeSize); + + // Pre-compute constants for Feistel network + uint256 size = 1; + uint256 bits = 0; + while (size < _indexCount) { + size <<= 1; + bits++; } - return _index; - } + uint256 halfBits = (bits + 1) / 2; + uint256 mask = (1 << halfBits) - 1; + + // Process all committee members + for (uint256 i = 0; i < _committeeSize; i++) { + uint256 permuted = i; + + // Apply Feistel rounds + do { + uint256 left = permuted >> halfBits; + uint256 right = permuted & mask; + + // 4 rounds of Feistel + for (uint256 round = 0; round < 4; round++) { + uint256 newLeft = right; + uint256 f = uint256(keccak256(abi.encodePacked(_seed, round, right))) & mask; + right = left ^ f; + left = newLeft; + } - function getOverrideValue(uint256 _index) internal view returns (uint256) { - return OVERRIDE_NAMESPACE.erc7201Slot().deriveMapping(_index).asUint256().tload(); + permuted = (left << halfBits) | right; + } while (permuted >= _indexCount); // Cycle walking for non-power-of-2 domains + + indices[i] = permuted; + } + + return indices; } /** - * @notice Compute the sample index for a given index, seed and index count. + * @notice Feistel network implementation for creating a bijective mapping + * Guarantees no collisions within the domain * - * @param _index - The index to shuffle - * @param _indexCount - The total number of indices - * @param _seed - The seed to use for shuffling + * @param _input - The input value to permute + * @param _max - The maximum value (exclusive) in the domain + * @param _seed - The seed for the permutation * - * @return shuffledIndex - The shuffled index + * @return The permuted value */ - function computeSampleIndex(uint256 _index, uint256 _indexCount, uint256 _seed) + function feistelPermute(uint256 _input, uint256 _max, uint256 _seed) internal pure returns (uint256) { - // Cannot modulo by 0 - if (_indexCount == 0) { - return 0; + // Find next power of 2 >= _max for balanced Feistel + uint256 size = 1; + uint256 bits = 0; + while (size < _max) { + size <<= 1; + bits++; } - return uint256(keccak256(abi.encodePacked(_seed, _index))) % _indexCount; + uint256 halfBits = (bits + 1) / 2; + uint256 mask = (1 << halfBits) - 1; + + uint256 left = _input >> halfBits; + uint256 right = _input & mask; + + // 4 rounds provides good mixing for cryptographic permutation + for (uint256 round = 0; round < 4; round++) { + uint256 newLeft = right; + uint256 f = uint256(keccak256(abi.encodePacked(_seed, round, right))) & mask; + right = left ^ f; + left = newLeft; + } + + return (left << halfBits) | right; } } diff --git a/l1-contracts/src/core/libraries/rollup/ValidatorSelectionLib.sol b/l1-contracts/src/core/libraries/rollup/ValidatorSelectionLib.sol index 4632cb631756..01ff78c9e2fc 100644 --- a/l1-contracts/src/core/libraries/rollup/ValidatorSelectionLib.sol +++ b/l1-contracts/src/core/libraries/rollup/ValidatorSelectionLib.sol @@ -143,10 +143,9 @@ library ValidatorSelectionLib { return; } + uint224 sampleSeed = getSampleSeed(_epochNumber); VerifyStack memory stack = VerifyStack({ - proposerIndex: computeProposerIndex( - _epochNumber, _slot, getSampleSeed(_epochNumber), targetCommitteeSize - ), + proposerIndex: computeProposerIndex(_epochNumber, _slot, sampleSeed, targetCommitteeSize), needed: (targetCommitteeSize << 1) / 3 + 1, // targetCommitteeSize * 2 / 3 + 1, but cheaper index: 0, signaturesRecovered: 0, @@ -235,18 +234,30 @@ library ValidatorSelectionLib { } Epoch epochNumber = _slot.epochFromSlot(); + ValidatorSelectionStorage storage store = getStorage(); uint224 sampleSeed = getSampleSeed(epochNumber); - (uint32 ts, uint256[] memory indices) = sampleValidatorsIndices(epochNumber, sampleSeed); - uint256 committeeSize = indices.length; - if (committeeSize == 0) { + uint32 ts = epochToSampleTime(epochNumber); + uint256 validatorSetSize = StakingLib.getAttesterCountAtTime(Timestamp.wrap(ts)); + uint256 targetCommitteeSize = store.targetCommitteeSize; + + if (targetCommitteeSize == 0 || validatorSetSize < targetCommitteeSize) { return (address(0), 0); } - uint256 proposerIndex = computeProposerIndex(epochNumber, _slot, sampleSeed, committeeSize); - return ( - StakingLib.getAttesterFromIndexAtTime(indices[proposerIndex], Timestamp.wrap(ts)), - proposerIndex + + // Compute which committee position is the proposer for this slot + uint256 proposerIndex = + computeProposerIndex(epochNumber, _slot, sampleSeed, targetCommitteeSize); + + // Get the validator index for that committee position in O(1) time + uint256 validatorIndex = SampleLib.computeCommitteeMemberAtIndex( + proposerIndex, targetCommitteeSize, validatorSetSize, sampleSeed ); + + address proposer = StakingLib.getAttesterFromIndexAtTime(validatorIndex, Timestamp.wrap(ts)); + setCachedProposer(_slot, proposer, proposerIndex); + + return (proposer, proposerIndex); } /** diff --git a/l1-contracts/test/RollupGetters.t.sol b/l1-contracts/test/RollupGetters.t.sol index fda69d82445d..41cd03adbeee 100644 --- a/l1-contracts/test/RollupGetters.t.sol +++ b/l1-contracts/test/RollupGetters.t.sol @@ -78,6 +78,15 @@ contract RollupShouldBeGetters is ValidatorSelectionTestBase { assertEq(committeeSize, expectedSize, "invalid getCommitteeCommittmentAt size"); assertNotEq(committeeCommitment, bytes32(0), "invalid committee commitment"); + // Check for no duplicates in each committee + for (uint256 i = 0; i < expectedSize; i++) { + for (uint256 j = i + 1; j < expectedSize; j++) { + assertNotEq(committee[i], committee[j], "duplicate found in getEpochCommittee"); + assertNotEq(committee2[i], committee2[j], "duplicate found in getCommitteeAt"); + assertNotEq(committee3[i], committee3[j], "duplicate found in getCurrentEpochCommittee"); + } + } + (, bytes32[] memory writes) = vm.accesses(address(rollup)); assertEq(writes.length, 0, "No writes should be done"); } diff --git a/l1-contracts/test/validator-selection/Sampling.t.sol b/l1-contracts/test/validator-selection/Sampling.t.sol index fa322ceda8b2..c51d20808256 100644 --- a/l1-contracts/test/validator-selection/Sampling.t.sol +++ b/l1-contracts/test/validator-selection/Sampling.t.sol @@ -25,6 +25,15 @@ contract Sampler { { return SampleLib.computeSampleIndex(_index, _indexCount, _seed); } + + function computeCommitteeMemberAtIndex( + uint256 _index, + uint256 _committeeSize, + uint256 _indexCount, + uint256 _seed + ) public pure returns (uint256) { + return SampleLib.computeCommitteeMemberAtIndex(_index, _committeeSize, _indexCount, _seed); + } } contract SamplingTest is Test { @@ -99,4 +108,62 @@ contract SamplingTest is Test { // Test modulo 0 case assertEq(sampler.computeSampleIndex(_index, 0, _seed), 0); } + + function testConstantTimeCommitteeMember() public { + uint256 committeeSize = 48; + uint256 indexCount = 1000; + uint256 seed = 12345; + + // Test that we get the same results as the original algorithm + uint256[] memory fullCommittee = sampler.computeCommittee(committeeSize, indexCount, seed); + + for (uint256 i = 0; i < committeeSize; i++) { + uint256 memberAtIndex = sampler.computeCommitteeMemberAtIndex(i, committeeSize, indexCount, seed); + assertEq(memberAtIndex, fullCommittee[i], "Member mismatch at index"); + } + } + + function testConstantTimeNoDuplicates(uint8 _committeeSize, uint16 _indexCount, uint256 _seed) public { + vm.assume(_committeeSize <= _indexCount); + vm.assume(_committeeSize > 0 && _committeeSize <= 100); // Reasonable bounds for testing + vm.assume(_indexCount > 0 && _indexCount <= 1000); + vm.assume(_seed != 0); + + // Get all members using constant-time function + uint256[] memory members = new uint256[](_committeeSize); + for (uint256 i = 0; i < _committeeSize; i++) { + members[i] = sampler.computeCommitteeMemberAtIndex(i, _committeeSize, _indexCount, _seed); + } + + // Check no duplicates + for (uint256 i = 0; i < _committeeSize; i++) { + for (uint256 j = i + 1; j < _committeeSize; j++) { + assertNotEq(members[i], members[j], "Duplicate found"); + } + } + } + + function testConstantTimeOutOfBounds() public { + uint256 committeeSize = 10; + uint256 indexCount = 100; + + vm.expectRevert( + abi.encodeWithSelector( + Errors.SampleLib__IndexOutOfBounds.selector, 10, committeeSize + ) + ); + sampler.computeCommitteeMemberAtIndex(10, committeeSize, indexCount, 1234); + } + + function testConstantTimeCommitteeTooLarge() public { + uint256 committeeSize = 101; + uint256 indexCount = 100; + + vm.expectRevert( + abi.encodeWithSelector( + Errors.SampleLib__SampleLargerThanIndex.selector, committeeSize, indexCount + ) + ); + sampler.computeCommitteeMemberAtIndex(0, committeeSize, indexCount, 1234); + } }