diff --git a/README.md b/README.md index 6332f5867a..86fe13e998 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ utils ├─ LibZip — "Library for compressing and decompressing bytes" ├─ Lifebuoy — "Class that allows for rescue of ETH, ERC20, ERC721 tokens" ├─ MerkleProofLib — "Library for verification of Merkle proofs" +├─ MerkleTreeLib — "Library for generating Merkle trees" ├─ MetadataReaderLib — "Library for reading contract metadata robustly" ├─ MinHeapLib — "Library for managing a min-heap in storage or memory" ├─ Multicallable — "Contract that enables a single call to call multiple methods on itself" diff --git a/docs/utils/merkletreelib.md b/docs/utils/merkletreelib.md new file mode 100644 index 0000000000..d8025e7b38 --- /dev/null +++ b/docs/utils/merkletreelib.md @@ -0,0 +1,135 @@ +# MerkleTreeLib + +Library for generating Merkle trees. + + + + + + + + +## Custom Errors + +### MerkleTreeLeafsEmpty() + +```solidity +error MerkleTreeLeafsEmpty() +``` + +At least 1 leaf is required to build the tree. + +### MerkleTreeOutOfBoundsAccess() + +```solidity +error MerkleTreeOutOfBoundsAccess() +``` + +Attempt to access a node with an out-of-bounds index. +Check if the tree has been built and has sufficient leafs and nodes. + +### MerkleTreeInvalidLeafIndices() + +```solidity +error MerkleTreeInvalidLeafIndices() +``` + +Leaf indices for multi proof must be strictly ascending and not empty. + +## Merkle Tree Operations + +### build(bytes32[]) + +```solidity +function build(bytes32[] memory leafs) + internal + pure + returns (bytes32[] memory result) +``` + +Builds and return a complete Merkle tree. +To make it a full Merkle tree, use `build(pad(leafs))`. + +### root(bytes32[]) + +```solidity +function root(bytes32[] memory t) internal pure returns (bytes32 result) +``` + +Returns the root. + +### numLeafs(bytes32[]) + +```solidity +function numLeafs(bytes32[] memory t) internal pure returns (uint256) +``` + +Returns the number of leafs. + +### numInternalNodes(bytes32[]) + +```solidity +function numInternalNodes(bytes32[] memory t) + internal + pure + returns (uint256) +``` + +Returns the number of internal nodes. + +### leaf(bytes32[],uint256) + +```solidity +function leaf(bytes32[] memory t, uint256 leafIndex) + internal + pure + returns (bytes32 result) +``` + +Returns the leaf at `leafIndex`. + +### leafProof(bytes32[],uint256) + +```solidity +function leafProof(bytes32[] memory t, uint256 leafIndex) + internal + pure + returns (bytes32[] memory result) +``` + +Returns the proof for the leaf at `leafIndex`. + +### nodeProof(bytes32[],uint256) + +```solidity +function nodeProof(bytes32[] memory t, uint256 nodeIndex) + internal + pure + returns (bytes32[] memory result) +``` + +Returns the proof for the node at `nodeIndex`. +This function can be used to prove the existence of internal nodes. + +### leafsMultiProof(bytes32[],uint256[]) + +```solidity +function leafsMultiProof(bytes32[] memory t, uint256[] memory leafIndices) + internal + pure + returns (bytes32[] memory proof, bool[] memory flags) +``` + +Returns proof and corresponding flags for multiple leafs. +The `leafIndices` must be non-empty and sorted in strictly ascending order. + +### pad(bytes32[],bytes32) + +```solidity +function pad(bytes32[] memory leafs, bytes32 defaultFill) + internal + pure + returns (bytes32[] memory result) +``` + +Returns a copy of leafs, with the length padded to a power of 2. \ No newline at end of file diff --git a/src/Milady.sol b/src/Milady.sol index 6aef5a3124..2e14615553 100644 --- a/src/Milady.sol +++ b/src/Milady.sol @@ -53,6 +53,7 @@ import "./utils/LibString.sol"; import "./utils/LibZip.sol"; import "./utils/Lifebuoy.sol"; import "./utils/MerkleProofLib.sol"; +import "./utils/MerkleTreeLib.sol"; import "./utils/MetadataReaderLib.sol"; import "./utils/MinHeapLib.sol"; import "./utils/Multicallable.sol"; diff --git a/src/utils/MerkleTreeLib.sol b/src/utils/MerkleTreeLib.sol new file mode 100644 index 0000000000..e898f93709 --- /dev/null +++ b/src/utils/MerkleTreeLib.sol @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +/// @notice Library for generating Merkle trees. +/// @author Solady (https://github.com/vectorized/solady/blob/main/src/utils/MerkleTreeLib.sol) +/// @author Modified from OpenZeppelin (https://github.com/OpenZeppelin/merkle-tree/blob/master/src/core.ts) +library MerkleTreeLib { + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* CUSTOM ERRORS */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev At least 1 leaf is required to build the tree. + error MerkleTreeLeafsEmpty(); + + /// @dev Attempt to access a node with an out-of-bounds index. + /// Check if the tree has been built and has sufficient leafs and nodes. + error MerkleTreeOutOfBoundsAccess(); + + /// @dev Leaf indices for multi proof must be strictly ascending and not empty. + error MerkleTreeInvalidLeafIndices(); + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* MERKLE TREE OPERATIONS */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev Builds and return a complete Merkle tree. + /// To make it a full Merkle tree, use `build(pad(leafs))`. + function build(bytes32[] memory leafs) internal pure returns (bytes32[] memory result) { + /// @solidity memory-safe-assembly + assembly { + result := mload(0x40) // `nodes`. + let l := mload(leafs) + if iszero(l) { + mstore(0x00, 0x089aff6e) // `MerkleTreeLeafsEmpty()`. + revert(0x1c, 0x04) + } + let n := sub(add(l, l), 1) + mstore(result, n) // `.length`. + let nodes := add(result, 0x20) + let f := add(nodes, shl(5, n)) + mstore(0x40, f) // Allocate memory. + let e := add(0x20, shl(5, l)) + for { let i := 0x20 } 1 {} { + mstore(sub(f, i), mload(add(leafs, i))) + i := add(i, 0x20) + if eq(i, e) { break } + } + if iszero(lt(l, 2)) { + for { let i := shl(5, sub(l, 2)) } 1 {} { + let left := mload(add(nodes, add(add(i, i), 0x20))) + let right := mload(add(nodes, add(add(i, i), 0x40))) + let c := shl(5, lt(left, right)) + mstore(c, right) + mstore(xor(c, 0x20), left) + mstore(add(nodes, i), keccak256(0x00, 0x40)) + if iszero(i) { break } + i := sub(i, 0x20) + } + } + } + } + + /// @dev Returns the root. + function root(bytes32[] memory t) internal pure returns (bytes32 result) { + /// @solidity memory-safe-assembly + assembly { + result := mload(add(0x20, t)) + if iszero(mload(t)) { + mstore(0x00, 0x7a856a38) // `MerkleTreeOutOfBoundsAccess()`. + revert(0x1c, 0x04) + } + } + } + + /// @dev Returns the number of leafs. + function numLeafs(bytes32[] memory t) internal pure returns (uint256) { + unchecked { + return t.length - (t.length >> 1); + } + } + + /// @dev Returns the number of internal nodes. + function numInternalNodes(bytes32[] memory t) internal pure returns (uint256) { + return t.length >> 1; + } + + /// @dev Returns the leaf at `leafIndex`. + function leaf(bytes32[] memory t, uint256 leafIndex) internal pure returns (bytes32 result) { + /// @solidity memory-safe-assembly + assembly { + let n := mload(t) + if iszero(lt(leafIndex, sub(n, shr(1, n)))) { + mstore(0x00, 0x7a856a38) // `MerkleTreeOutOfBoundsAccess()`. + revert(0x1c, 0x04) + } + result := mload(add(t, shl(5, sub(n, leafIndex)))) + } + } + + /// @dev Returns the proof for the leaf at `leafIndex`. + function leafProof(bytes32[] memory t, uint256 leafIndex) + internal + pure + returns (bytes32[] memory result) + { + /// @solidity memory-safe-assembly + assembly { + result := mload(0x40) + let n := mload(t) + if iszero(lt(leafIndex, sub(n, shr(1, n)))) { + mstore(0x00, 0x7a856a38) // `MerkleTreeOutOfBoundsAccess()`. + revert(0x1c, 0x04) + } + let o := add(result, 0x20) + for { let i := sub(n, add(1, leafIndex)) } i { i := shr(1, sub(i, 1)) } { + mstore(o, mload(add(t, shl(5, add(i, shl(1, and(1, i))))))) + o := add(o, 0x20) + } + mstore(0x40, o) // Allocate memory. + mstore(result, shr(5, sub(o, add(result, 0x20)))) // Store length. + } + } + + /// @dev Returns the proof for the node at `nodeIndex`. + /// This function can be used to prove the existence of internal nodes. + function nodeProof(bytes32[] memory t, uint256 nodeIndex) + internal + pure + returns (bytes32[] memory result) + { + /// @solidity memory-safe-assembly + assembly { + result := mload(0x40) + if iszero(lt(nodeIndex, mload(t))) { + mstore(0x00, 0x7a856a38) // `MerkleTreeOutOfBoundsAccess()`. + revert(0x1c, 0x04) + } + let o := add(result, 0x20) + for { let i := nodeIndex } i { i := shr(1, sub(i, 1)) } { + mstore(o, mload(add(t, shl(5, add(i, shl(1, and(1, i))))))) + o := add(o, 0x20) + } + mstore(0x40, o) // Allocate memory. + mstore(result, shr(5, sub(o, add(result, 0x20)))) // Store length. + } + } + + /// @dev Returns proof and corresponding flags for multiple leafs. + /// The `leafIndices` must be non-empty and sorted in strictly ascending order. + function leafsMultiProof(bytes32[] memory t, uint256[] memory leafIndices) + internal + pure + returns (bytes32[] memory proof, bool[] memory flags) + { + /// @solidity memory-safe-assembly + assembly { + function gen(leafIndices_, t_, proof_, flags_) -> _flagsLen, _proofLen { + let q_ := mload(0x40) // Circular buffer. + let c_ := mload(leafIndices_) // Capacity of circular buffer. + let e_ := mload(leafIndices_) // End index of circular buffer. + let b_ := 0 // Start index of circular buffer. + if iszero(e_) { + mstore(0x00, 0xe9729976) // `MerkleTreeInvalidLeafIndices()`. + revert(0x1c, 0x04) + } + for { + let n_ := mload(t_) // Num nodes. + let l_ := sub(n_, shr(1, n_)) // Num leafs. + let p_ := 0 + let i_ := 0 + } 1 {} { + let j_ := mload(add(add(leafIndices_, 0x20), shl(5, i_))) // Leaf index. + if flags_ { + if iszero(lt(j_, l_)) { + mstore(0x00, 0x7a856a38) // `MerkleTreeOutOfBoundsAccess()`. + revert(0x1c, 0x04) + } + if iszero(or(iszero(i_), gt(j_, p_))) { + mstore(0x00, 0xe9729976) // `MerkleTreeInvalidLeafIndices()`. + revert(0x1c, 0x04) + } + p_ := j_ + } + mstore(add(q_, shl(5, i_)), sub(n_, add(1, j_))) + i_ := add(i_, 1) + if eq(i_, e_) { break } + } + for {} 1 {} { + if iszero(lt(b_, e_)) { break } + let j_ := mload(add(q_, shl(5, mod(b_, c_)))) // Current. + if iszero(j_) { break } + b_ := add(b_, 1) + let s_ := add(j_, shl(1, and(j_, 1))) // Sibling (+1). + _flagsLen := add(_flagsLen, 0x20) + let f_ := and(eq(s_, add(1, mload(add(q_, shl(5, mod(b_, c_)))))), lt(b_, e_)) + b_ := add(b_, f_) + if flags_ { mstore(add(flags_, _flagsLen), f_) } + if iszero(f_) { + _proofLen := add(_proofLen, 0x20) + if flags_ { mstore(add(proof_, _proofLen), mload(add(t_, shl(5, s_)))) } + } + mstore(add(q_, shl(5, mod(e_, c_))), shr(1, sub(j_, 1))) + e_ := add(e_, 1) + } + _proofLen := shr(5, _proofLen) + _flagsLen := shr(5, _flagsLen) + } + let flagsLen, proofLen := gen(leafIndices, t, 0x00, 0x00) + proof := mload(0x40) + mstore(proof, proofLen) + flags := add(add(proof, 0x20), shl(5, proofLen)) + mstore(flags, flagsLen) + mstore(0x40, add(add(flags, 0x20), shl(5, flagsLen))) // Allocate memory. + flagsLen, proofLen := gen(leafIndices, t, proof, flags) + } + } + + /// @dev Returns a copy of leafs, with the length padded to a power of 2. + function pad(bytes32[] memory leafs, bytes32 defaultFill) + internal + pure + returns (bytes32[] memory result) + { + /// @solidity memory-safe-assembly + assembly { + result := mload(0x40) + let l := mload(leafs) + if iszero(l) { + mstore(0x00, 0x089aff6e) // `MerkleTreeLeafsEmpty()`. + revert(0x1c, 0x04) + } + let p := 1 // Padded length. + for {} lt(p, l) {} { p := add(p, p) } + mstore(result, p) // Store length. + mstore(0x40, add(result, add(0x20, shl(5, p)))) // Allocate memory. + let d := sub(result, leafs) + let copyEnd := add(add(leafs, 0x20), shl(5, l)) + let end := add(add(leafs, 0x20), shl(5, p)) + mstore(0x00, defaultFill) + for { let i := add(leafs, 0x20) } 1 {} { + mstore(add(i, d), mload(mul(i, lt(i, copyEnd)))) + i := add(i, 0x20) + if eq(i, end) { break } + } + } + } +} diff --git a/test/MerkleTreeLib.t.sol b/test/MerkleTreeLib.t.sol new file mode 100644 index 0000000000..7006490efd --- /dev/null +++ b/test/MerkleTreeLib.t.sol @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +import "./utils/SoladyTest.sol"; +import {MerkleTreeLib} from "../src/utils/MerkleTreeLib.sol"; +import {MerkleProofLib} from "../src/utils/MerkleProofLib.sol"; +import {LibSort} from "../src/utils/LibSort.sol"; +import {LibPRNG} from "../src/utils/LibPRNG.sol"; +import {EfficientHashLib} from "../src/utils/EfficientHashLib.sol"; + +contract MerkleTreeLibTest is SoladyTest { + using MerkleTreeLib for bytes32[]; + using LibPRNG for *; + + function testBuildCompleteMerkleTree(bytes32[] memory leafs, bytes32 r) public { + _maybeBrutalizeMemory(r); + if (leafs.length <= 1) { + leafs = new bytes32[](1); + leafs[0] = r; + } + bytes32[] memory t = MerkleTreeLib.build(leafs); + assertEq(t.length, leafs.length * 2 - 1); + if (leafs.length == 1) { + assertEq(t[0], r); + } else { + assertNotEq(t[0], 0); + } + assertEq(t.root(), t[0]); + assertEq(leafs.length, t.numLeafs()); + assertEq(t.length, t.numLeafs() + t.numInternalNodes()); + _checkMemory(t); + if (leafs.length >= 1) { + uint256 i = _randomUniform() % leafs.length; + assertEq(t.leaf(i), leafs[i]); + } + } + + function testPad(bytes32[] memory leafs, bytes32 defaultFill, uint256 r) public { + _maybeBrutalizeMemory(r); + if (leafs.length == 0) return; + assertEq(MerkleTreeLib.pad(leafs, defaultFill), _padOriginal(leafs, defaultFill)); + _checkMemory(); + } + + function _padOriginal(bytes32[] memory leafs, bytes32 defaultFill) + internal + pure + returns (bytes32[] memory result) + { + unchecked { + uint256 p = 1; + while (p < leafs.length) p = p << 1; + result = new bytes32[](p); + for (uint256 i; i < p; ++i) { + if (i < leafs.length) { + result[i] = leafs[i]; + } else { + result[i] = defaultFill; + } + } + } + } + + function _maybeBrutalizeMemory(uint256 r) internal view { + _maybeBrutalizeMemory(bytes32(r)); + } + + function _maybeBrutalizeMemory(bytes32 r) internal view { + uint256 h = uint256(EfficientHashLib.hash(r, "hehe")); + if (h & 0xf0 == 0) _misalignFreeMemoryPointer(); + if (h & 0x0f == 0) _brutalizeMemory(); + } + + function testBuildAndGetLeaf(bytes32[] memory leafs, uint256 leafIndex) public { + if (leafs.length == 0) return; + + if (leafIndex < leafs.length) { + assertEq(this.buildAndGetLeaf(leafs, leafIndex), leafs[leafIndex]); + } else { + vm.expectRevert(MerkleTreeLib.MerkleTreeOutOfBoundsAccess.selector); + this.buildAndGetLeaf(leafs, leafIndex); + } + } + + function buildAndGetLeaf(bytes32[] memory leafs, uint256 leafIndex) + public + pure + returns (bytes32) + { + return MerkleTreeLib.build(leafs).leaf(leafIndex); + } + + function testBuildAndGetLeafProof(bytes32[] memory leafs, uint256 leafIndex) public { + if (leafs.length == 0) return _testBuildAndGetRoot(leafs); + bytes32[] memory t = MerkleTreeLib.build(leafs); + if (leafIndex < leafs.length) { + bytes32[] memory proof = this.buildAndGetLeafProof(leafs, leafIndex); + assertTrue(MerkleProofLib.verify(proof, t.root(), leafs[leafIndex])); + } else { + vm.expectRevert(MerkleTreeLib.MerkleTreeOutOfBoundsAccess.selector); + this.buildAndGetLeafProof(leafs, leafIndex); + } + } + + function buildAndGetLeafProof(bytes32[] memory leafs, uint256 leafIndex) + public + pure + returns (bytes32[] memory proof) + { + bytes32[] memory t = MerkleTreeLib.build(leafs); + proof = t.leafProof(leafIndex); + _checkMemory(); + } + + function testBuildAndGetNodeProof(bytes32[] memory leafs, uint256 nodeIndex) public { + if (leafs.length == 0) return _testBuildAndGetRoot(leafs); + bytes32[] memory t = MerkleTreeLib.build(leafs); + if (nodeIndex < t.length) { + bytes32[] memory proof = this.buildAndGetNodeProof(leafs, nodeIndex); + assertTrue(MerkleProofLib.verify(proof, t.root(), t[nodeIndex])); + } else { + vm.expectRevert(MerkleTreeLib.MerkleTreeOutOfBoundsAccess.selector); + this.buildAndGetNodeProof(leafs, nodeIndex); + } + } + + function buildAndGetNodeProof(bytes32[] memory leafs, uint256 nodeIndex) + public + pure + returns (bytes32[] memory proof) + { + bytes32[] memory t = MerkleTreeLib.build(leafs); + proof = t.nodeProof(nodeIndex); + _checkMemory(); + } + + function _testBuildAndGetRoot(bytes32[] memory leafs) internal { + vm.expectRevert(MerkleTreeLib.MerkleTreeLeafsEmpty.selector); + this.buildAndGetRoot(leafs); + } + + function buildAndGetRoot(bytes32[] memory leafs) public pure returns (bytes32) { + return MerkleTreeLib.build(leafs).root(); + } + + function testGetRootFromEmptyTree() public { + vm.expectRevert(MerkleTreeLib.MerkleTreeOutOfBoundsAccess.selector); + this.getRootFromEmptyTree(); + } + + function getRootFromEmptyTree() public pure returns (bytes32) { + return (new bytes32[](0)).root(); + } + + struct TestMultiProofTemps { + bytes32[] leafs; + uint256[] leafIndices; + bytes32[] gathered; + bytes32[] tree; + bytes32[] proof; + bool[] flags; + } + + function testBuildAndGetLeafsMultiProof(bytes32) public { + TestMultiProofTemps memory t; + t.leafs = new bytes32[](_bound(_random(), 1, 128)); + for (uint256 i; i < t.leafs.length; ++i) { + t.leafs[i] = bytes32(_random()); + } + t.leafIndices = _generateUniqueLeafIndices(t.leafs); + t.tree = MerkleTreeLib.build(t.leafs); + (t.proof, t.flags) = t.tree.leafsMultiProof(t.leafIndices); + t.gathered = _gatherLeafs(t.leafs, t.leafIndices); + assertTrue(MerkleProofLib.verifyMultiProof(t.proof, t.tree.root(), t.gathered, t.flags)); + } + + function _generateUniqueLeafIndices(bytes32[] memory leafs) + internal + returns (uint256[] memory indices) + { + indices = new uint256[](leafs.length); + for (uint256 i; i < leafs.length; ++i) { + indices[i] = i; + } + LibPRNG.PRNG memory prng; + prng.seed(_randomUniform()); + prng.shuffle(indices); + uint256 n = _bound(_random(), 1, indices.length); + /// @solidity memory-safe-assembly + assembly { + mstore(indices, n) + } + LibSort.sort(indices); + } + + function _gatherLeafs(bytes32[] memory leafs, uint256[] memory indices) + internal + pure + returns (bytes32[] memory gathered) + { + gathered = new bytes32[](indices.length); + for (uint256 i; i < indices.length; ++i) { + gathered[i] = leafs[indices[i]]; + } + } +}