Skip to content

Commit 0c19df1

Browse files
authored
⚡️ Optimize groupSum (#1435)
1 parent f281ae5 commit 0c19df1

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

src/utils/LibSort.sol

Lines changed: 48 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -583,43 +583,62 @@ library LibSort {
583 583

584 584
/// @dev Sorts and uniquifies `keys`. Updates `values` with the grouped sums by key.
585 585
function groupSum(uint256[] memory keys, uint256[] memory values) internal pure {
586-
uint256 m;
587 586
/// @solidity memory-safe-assembly
588 587
assembly {
589-
m := mload(0x40) // Cache the free memory pointer, for freeing the memory.
590-
if iszero(eq(mload(keys), mload(values))) {
588+
function mswap(i_, j_) {
589+
let t_ := mload(i_)
590+
mstore(i_, mload(j_))
591+
mstore(j_, t_)
592+
}
593+
function sortInner(l_, h_, d_) {
594+
let p_ := mload(l_)
595+
let j_ := l_
596+
for { let i_ := add(l_, 0x20) } 1 {} {
597+
if lt(mload(i_), p_) {
598+
j_ := add(j_, 0x20)
599+
mswap(i_, j_)
600+
mswap(add(i_, d_), add(j_, d_))
601+
}
602+
i_ := add(0x20, i_)
603+
if iszero(lt(i_, h_)) { break }
604+
}
605+
mswap(l_, j_)
606+
mswap(add(l_, d_), add(j_, d_))
607+
if iszero(gt(add(0x40, l_), j_)) { sortInner(l_, j_, d_) }
608+
if iszero(gt(add(0x60, j_), h_)) { sortInner(add(j_, 0x20), h_, d_) }
609+
}
610+
let n := mload(values)
611+
if iszero(eq(mload(keys), n)) {
591 612
mstore(0x00, 0x4e487b71)
592 613
mstore(0x20, 0x32) // Array out of bounds panic if the arrays lengths differ.
593 614
revert(0x1c, 0x24)
594 615
}
595-
}
596-
if (keys.length == uint256(0)) return;
597-
(uint256[] memory oriKeys, uint256[] memory oriValues) = (copy(keys), copy(values));
598-
insertionSort(keys); // Optimize for small `n` and bytecode size.
599-
uniquifySorted(keys);
600-
/// @solidity memory-safe-assembly
601-
assembly {
602-
let d := sub(values, keys)
603-
let w := not(0x1f)
604-
let s := add(keys, 0x20) // Location of `keys[0]`.
605-
mstore(values, mload(keys)) // Truncate.
606-
calldatacopy(add(s, d), calldatasize(), shl(5, mload(keys))) // Zeroize.
607-
for { let i := shl(5, mload(oriKeys)) } 1 {} {
608-
let k := mload(add(oriKeys, i))
609-
let v := mload(add(oriValues, i))
610-
let j := s // Just do a linear scan to optimize for small `n` and bytecode size.
611-
for {} iszero(eq(mload(j), k)) {} { j := add(j, 0x20) }
612-
j := add(j, d) // Convert `j` to point into `values`.
613-
mstore(j, add(mload(j), v))
614-
if lt(mload(j), v) {
615-
mstore(0x00, 0x4e487b71)
616-
mstore(0x20, 0x11) // Overflow panic if the addition overflows.
617-
revert(0x1c, 0x24)
616+
if iszero(lt(n, 2)) {
617+
let d := sub(values, keys)
618+
let x := add(keys, 0x20)
619+
let end := add(x, shl(5, n))
620+
sortInner(x, end, d)
621+
let s := mload(add(x, d))
622+
for { let y := add(keys, 0x40) } 1 {} {
623+
if iszero(eq(mload(x), mload(y))) {
624+
mstore(add(x, d), s) // Write sum.
625+
s := 0
626+
x := add(x, 0x20)
627+
mstore(x, mload(y))
628+
}
629+
s := add(s, mload(add(y, d)))
630+
if lt(s, mload(add(y, d))) {
631+
mstore(0x00, 0x4e487b71)
632+
mstore(0x20, 0x11) // Overflow panic if the addition overflows.
633+
revert(0x1c, 0x24)
634+
}
635+
y := add(y, 0x20)
636+
if eq(y, end) { break }
618 637
}
619-
i := add(i, w) // `sub(i, 0x20)`.
620-
if iszero(i) { break }
638+
mstore(add(x, d), s) // Write sum.
639+
mstore(keys, shr(5, sub(x, keys))) // Truncate.
640+
mstore(values, mload(keys)) // Truncate.
621 641
}
622-
mstore(0x40, m) // Frees the memory allocated for the temporary copies.
623 642
}
624 643
}
625 644

0 commit comments

Comments
 (0)