diff --git a/src/utils/MerkleTreeLib.sol b/src/utils/MerkleTreeLib.sol index 8d673a691d..72a0850a2f 100644 --- a/src/utils/MerkleTreeLib.sol +++ b/src/utils/MerkleTreeLib.sol @@ -113,22 +113,14 @@ library MerkleTreeLib { pure returns (bytes32[] memory result) { + uint256 nodeIndex; /// @solidity memory-safe-assembly assembly { - result := mload(0x40) let n := mload(t) - if iszero(lt(leafIndex, sub(n, shr(1, n)))) { - mstore(0x00, 0x7a856a38) // `MerkleTreeOutOfBoundsAccess()`. - revert(0x1c, 0x04) - } - let o := add(result, 0x20) - for { let i := sub(n, add(1, leafIndex)) } i { i := shr(1, sub(i, 1)) } { - mstore(o, mload(add(t, shl(5, add(i, shl(1, and(1, i))))))) - o := add(o, 0x20) - } - mstore(0x40, o) // Allocate memory. - mstore(result, shr(5, sub(o, add(result, 0x20)))) // Store length. + nodeIndex := sub(n, add(1, leafIndex)) + if iszero(lt(leafIndex, sub(n, shr(1, n)))) { nodeIndex := not(0) } } + result = nodeProof(t, nodeIndex); } /// @dev Returns the proof for the node at `nodeIndex`. @@ -167,16 +159,12 @@ library MerkleTreeLib { function gen(leafIndices_, t_, proof_, flags_) -> _flagsLen, _proofLen { let q_ := mload(0x40) // Circular buffer. let c_ := mload(leafIndices_) // Capacity of circular buffer. - let e_ := mload(leafIndices_) // End index of circular buffer. + let e_ := c_ // End index of circular buffer. let b_ := 0 // Start index of circular buffer. - if iszero(e_) { - mstore(0x00, 0xe9729976) // `MerkleTreeInvalidLeafIndices()`. - revert(0x1c, 0x04) - } for { let n_ := mload(t_) // Num nodes. let l_ := sub(n_, shr(1, n_)) // Num leafs. - let p_ := 0 + let p_ := not(0) let i_ := 0 } 1 {} { let j_ := mload(add(add(leafIndices_, 0x20), shl(5, i_))) // Leaf index. @@ -185,7 +173,7 @@ library MerkleTreeLib { mstore(0x00, 0x7a856a38) // `MerkleTreeOutOfBoundsAccess()`. revert(0x1c, 0x04) } - if iszero(or(iszero(i_), gt(j_, p_))) { + if iszero(sgt(j_, p_)) { mstore(0x00, 0xe9729976) // `MerkleTreeInvalidLeafIndices()`. revert(0x1c, 0x04) } @@ -204,10 +192,10 @@ library MerkleTreeLib { _flagsLen := add(_flagsLen, 0x20) let f_ := and(eq(s_, add(1, mload(add(q_, shl(5, mod(b_, c_)))))), lt(b_, e_)) b_ := add(b_, f_) - if flags_ { mstore(add(flags_, _flagsLen), f_) } - if iszero(f_) { - _proofLen := add(_proofLen, 0x20) - if flags_ { mstore(add(proof_, _proofLen), mload(add(t_, shl(5, s_)))) } + _proofLen := add(_proofLen, shl(5, iszero(f_))) + if flags_ { + mstore(add(flags_, _flagsLen), f_) + mstore(mul(iszero(f_), add(proof_, _proofLen)), mload(add(t_, shl(5, s_)))) } mstore(add(q_, shl(5, mod(e_, c_))), shr(1, sub(j_, 1))) e_ := add(e_, 1) @@ -215,6 +203,10 @@ library MerkleTreeLib { _proofLen := shr(5, _proofLen) _flagsLen := shr(5, _flagsLen) } + if iszero(mload(leafIndices)) { + mstore(0x00, 0xe9729976) // `MerkleTreeInvalidLeafIndices()`. + revert(0x1c, 0x04) + } let flagsLen, proofLen := gen(leafIndices, t, 0x00, 0x00) proof := mload(0x40) mstore(proof, proofLen) diff --git a/test/MerkleTreeLib.t.sol b/test/MerkleTreeLib.t.sol index 053a3a2334..b34adfdac0 100644 --- a/test/MerkleTreeLib.t.sol +++ b/test/MerkleTreeLib.t.sol @@ -216,4 +216,13 @@ contract MerkleTreeLibTest is SoladyTest { gathered[i] = leafs[indices[i]]; } } + + function testMultiProofRevertsForEmptyLeafs() public { + vm.expectRevert(MerkleTreeLib.MerkleTreeInvalidLeafIndices.selector); + this.multiProofRevertsForEmptyLeafs(); + } + + function multiProofRevertsForEmptyLeafs() public pure { + (new bytes32[](1)).leafsMultiProof(new uint256[](0)); + } }