diff --git a/src/utils/LibBytes.sol b/src/utils/LibBytes.sol index 8e6fa17703..7541b5b693 100644 --- a/src/utils/LibBytes.sol +++ b/src/utils/LibBytes.sol @@ -842,6 +842,32 @@ library LibBytes { } } + /// @dev Checks if `x` is in `a`. Assumes `a` has been checked. + function checkInCalldata(bytes calldata x, bytes calldata a) internal pure { + /// @solidity memory-safe-assembly + assembly { + if or( + or(lt(x.offset, a.offset), gt(add(x.offset, x.length), add(a.length, a.offset))), + shr(64, or(x.length, x.offset)) + ) { revert(0x00, 0x00) } + } + } + + /// @dev Checks if `x` is in `a`. Assumes `a` has been checked. + function checkInCalldata(bytes[] calldata x, bytes calldata a) internal pure { + /// @solidity memory-safe-assembly + assembly { + let e := sub(add(a.length, a.offset), 0x20) + if or(lt(x.offset, a.offset), shr(64, x.offset)) { revert(0x00, 0x00) } + for { let i := 0 } iszero(eq(x.length, i)) { i := add(i, 1) } { + let o := calldataload(add(x.offset, shl(5, i))) + let t := add(o, x.offset) + let l := calldataload(t) + if or(shr(64, or(l, o)), gt(add(t, l), e)) { revert(0x00, 0x00) } + } + } + } + /// @dev Returns empty calldata bytes. For silencing the compiler. function emptyCalldata() internal pure returns (bytes calldata result) { /// @solidity memory-safe-assembly diff --git a/src/utils/g/LibBytes.sol b/src/utils/g/LibBytes.sol index 6b384e0d50..a536ad4d7d 100644 --- a/src/utils/g/LibBytes.sol +++ b/src/utils/g/LibBytes.sol @@ -846,6 +846,32 @@ library LibBytes { } } + /// @dev Checks if `x` is in `a`. Assumes `a` has been checked. + function checkInCalldata(bytes calldata x, bytes calldata a) internal pure { + /// @solidity memory-safe-assembly + assembly { + if or( + or(lt(x.offset, a.offset), gt(add(x.offset, x.length), add(a.length, a.offset))), + shr(64, or(x.length, x.offset)) + ) { revert(0x00, 0x00) } + } + } + + /// @dev Checks if `x` is in `a`. Assumes `a` has been checked. + function checkInCalldata(bytes[] calldata x, bytes calldata a) internal pure { + /// @solidity memory-safe-assembly + assembly { + let e := sub(add(a.length, a.offset), 0x20) + if or(lt(x.offset, a.offset), shr(64, x.offset)) { revert(0x00, 0x00) } + for { let i := 0 } iszero(eq(x.length, i)) { i := add(i, 1) } { + let o := calldataload(add(x.offset, shl(5, i))) + let t := add(o, x.offset) + let l := calldataload(t) + if or(shr(64, or(l, o)), gt(add(t, l), e)) { revert(0x00, 0x00) } + } + } + } + /// @dev Returns empty calldata bytes. For silencing the compiler. function emptyCalldata() internal pure returns (bytes calldata result) { /// @solidity memory-safe-assembly diff --git a/test/LibBytes.t.sol b/test/LibBytes.t.sol index 1f92dc974a..a1285f8537 100644 --- a/test/LibBytes.t.sol +++ b/test/LibBytes.t.sol @@ -373,4 +373,48 @@ contract LibBytesTest is SoladyTest { assertEq(uint160(LibBytes.msbToAddress(x)), msb); assertEq(uint160(LibBytes.lsbToAddress(x)), lsb); } + + function testCheckInCalldata(bytes memory child) public view { + this.checkInCalldata(child, abi.encode(child)); + } + + function testCheckInCalldata() public pure { + LibBytes.checkInCalldata(msg.data, msg.data); + } + + function checkInCalldata(bytes calldata expectedChild, bytes calldata encoded) public pure { + bytes calldata child; + /// @solidity memory-safe-assembly + assembly { + child.offset := add(0x20, add(encoded.offset, calldataload(encoded.offset))) + child.length := calldataload(add(encoded.offset, calldataload(encoded.offset))) + } + LibBytes.checkInCalldata(child, encoded); + LibBytes.checkInCalldata(child, msg.data); + LibBytes.checkInCalldata(encoded, msg.data); + require(keccak256(expectedChild) == keccak256(child)); + } + + function testCheckInCalldata(bytes[] memory children) public view { + this.checkInCalldata(children, abi.encode(children)); + } + + function checkInCalldata(bytes[] calldata expectedChildren, bytes calldata encoded) + public + pure + { + bytes[] calldata children; + /// @solidity memory-safe-assembly + assembly { + children.offset := add(0x20, add(encoded.offset, calldataload(encoded.offset))) + children.length := calldataload(add(encoded.offset, calldataload(encoded.offset))) + } + LibBytes.checkInCalldata(children, encoded); + LibBytes.checkInCalldata(expectedChildren, msg.data); + LibBytes.checkInCalldata(children, msg.data); + require(expectedChildren.length == children.length); + for (uint256 i; i < children.length; ++i) { + require(keccak256(expectedChildren[i]) == keccak256(children[i])); + } + } }