Skip to content

Commit f8b1ddf

Browse files
Amxxernestognw
andauthored
Add variants of Array.sort for address[] and bytes32[] (#4883)
Co-authored-by: Ernesto García <[email protected]>
1 parent 72c0da9 commit f8b1ddf

File tree

12 files changed

+263
-116
lines changed

12 files changed

+263
-116
lines changed

.changeset/dirty-cobras-smile.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
'openzeppelin-solidity': minor
33
---
44

5-
`Arrays`: add a `sort` function.
5+
`Arrays`: add a `sort` functions for `address[]`, `bytes32[]` and `uint256[]` memory arrays.

contracts/mocks/ArraysMock.sol

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ contract Uint256ArraysMock {
3636
function unsafeAccess(uint256 pos) external view returns (uint256) {
3737
return _array.unsafeAccess(pos).value;
3838
}
39+
40+
function sort(uint256[] memory array) external pure returns (uint256[] memory) {
41+
return array.sort();
42+
}
43+
44+
function sortReverse(uint256[] memory array) external pure returns (uint256[] memory) {
45+
return array.sort(_reverse);
46+
}
47+
48+
function _reverse(uint256 a, uint256 b) private pure returns (bool) {
49+
return a > b;
50+
}
3951
}
4052

4153
contract AddressArraysMock {
@@ -50,6 +62,18 @@ contract AddressArraysMock {
5062
function unsafeAccess(uint256 pos) external view returns (address) {
5163
return _array.unsafeAccess(pos).value;
5264
}
65+
66+
function sort(address[] memory array) external pure returns (address[] memory) {
67+
return array.sort();
68+
}
69+
70+
function sortReverse(address[] memory array) external pure returns (address[] memory) {
71+
return array.sort(_reverse);
72+
}
73+
74+
function _reverse(address a, address b) private pure returns (bool) {
75+
return uint160(a) > uint160(b);
76+
}
5377
}
5478

5579
contract Bytes32ArraysMock {
@@ -64,4 +88,16 @@ contract Bytes32ArraysMock {
6488
function unsafeAccess(uint256 pos) external view returns (bytes32) {
6589
return _array.unsafeAccess(pos).value;
6690
}
91+
92+
function sort(bytes32[] memory array) external pure returns (bytes32[] memory) {
93+
return array.sort();
94+
}
95+
96+
function sortReverse(bytes32[] memory array) external pure returns (bytes32[] memory) {
97+
return array.sort(_reverse);
98+
}
99+
100+
function _reverse(bytes32 a, bytes32 b) private pure returns (bool) {
101+
return uint256(a) > uint256(b);
102+
}
67103
}

contracts/utils/Arrays.sol

Lines changed: 141 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ library Arrays {
1313
using StorageSlot for bytes32;
1414

1515
/**
16-
* @dev Sort an array (in memory) in increasing order.
16+
* @dev Sort an array of bytes32 (in memory) following the provided comparator function.
1717
*
1818
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for
1919
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
@@ -23,55 +23,167 @@ library Arrays {
2323
* when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
2424
* consume more gas than is available in a block, leading to potential DoS.
2525
*/
26+
function sort(
27+
bytes32[] memory array,
28+
function(bytes32, bytes32) pure returns (bool) comp
29+
) internal pure returns (bytes32[] memory) {
30+
_quickSort(_begin(array), _end(array), comp);
31+
return array;
32+
}
33+
34+
/**
35+
* @dev Variant of {sort} that sorts an array of bytes32 in increasing order.
36+
*/
37+
function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) {
38+
return sort(array, _defaultComp);
39+
}
40+
41+
/**
42+
* @dev Variant of {sort} that sorts an array of address following a provided comparator function.
43+
*/
44+
function sort(
45+
address[] memory array,
46+
function(address, address) pure returns (bool) comp
47+
) internal pure returns (address[] memory) {
48+
sort(_castToBytes32Array(array), _castToBytes32Comp(comp));
49+
return array;
50+
}
51+
52+
/**
53+
* @dev Variant of {sort} that sorts an array of address in increasing order.
54+
*/
55+
function sort(address[] memory array) internal pure returns (address[] memory) {
56+
sort(_castToBytes32Array(array), _defaultComp);
57+
return array;
58+
}
59+
60+
/**
61+
* @dev Variant of {sort} that sorts an array of uint256 following a provided comparator function.
62+
*/
63+
function sort(
64+
uint256[] memory array,
65+
function(uint256, uint256) pure returns (bool) comp
66+
) internal pure returns (uint256[] memory) {
67+
sort(_castToBytes32Array(array), _castToBytes32Comp(comp));
68+
return array;
69+
}
70+
71+
/**
72+
* @dev Variant of {sort} that sorts an array of uint256 in increasing order.
73+
*/
2674
function sort(uint256[] memory array) internal pure returns (uint256[] memory) {
27-
_quickSort(array, 0, array.length);
75+
sort(_castToBytes32Array(array), _defaultComp);
2876
return array;
2977
}
3078

3179
/**
32-
* @dev Performs a quick sort on an array in memory. The array is sorted in increasing order.
80+
* @dev Performs a quick sort of a segment of memory. The segment sorted starts at `begin` (inclusive), and stops
81+
* at end (exclusive). Sorting follows the `comp` comparator.
3382
*
34-
* Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in
35-
* subcalls.
83+
* Invariant: `begin <= end`. This is the case when initially called by {sort} and is preserved in subcalls.
84+
*
85+
* IMPORTANT: Memory locations between `begin` and `end` are not validated/zeroed. This function should
86+
* be used only if the limits are within a memory array.
3687
*/
37-
function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure {
88+
function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure returns (bool) comp) private pure {
3889
unchecked {
39-
// Can't overflow given `i <= j`
40-
if (j - i < 2) return;
90+
if (end - begin < 0x40) return;
4191

4292
// Use first element as pivot
43-
uint256 pivot = unsafeMemoryAccess(array, i);
93+
bytes32 pivot = _mload(begin);
4494
// Position where the pivot should be at the end of the loop
45-
uint256 index = i;
46-
47-
for (uint256 k = i + 1; k < j; ++k) {
48-
// Unsafe access is safe given `k < j <= array.length`.
49-
if (unsafeMemoryAccess(array, k) < pivot) {
50-
// If array[k] is smaller than the pivot, we increment the index and move array[k] there.
51-
_swap(array, ++index, k);
95+
uint256 pos = begin;
96+
97+
for (uint256 it = begin + 0x20; it < end; it += 0x20) {
98+
if (comp(_mload(it), pivot)) {
99+
// If the value stored at the iterator's position comes before the pivot, we increment the
100+
// position of the pivot and move the value there.
101+
pos += 0x20;
102+
_swap(pos, it);
52103
}
53104
}
54105

55-
// Swap pivot into place
56-
_swap(array, i, index);
106+
_swap(begin, pos); // Swap pivot into place
107+
_quickSort(begin, pos, comp); // Sort the left side of the pivot
108+
_quickSort(pos + 0x20, end, comp); // Sort the right side of the pivot
109+
}
110+
}
111+
112+
/**
113+
* @dev Pointer to the memory location of the first element of `array`.
114+
*/
115+
function _begin(bytes32[] memory array) private pure returns (uint256 ptr) {
116+
/// @solidity memory-safe-assembly
117+
assembly {
118+
ptr := add(array, 0x20)
119+
}
120+
}
121+
122+
/**
123+
* @dev Pointer to the memory location of the first memory word (32bytes) after `array`. This is the memory word
124+
* that comes just after the last element of the array.
125+
*/
126+
function _end(bytes32[] memory array) private pure returns (uint256 ptr) {
127+
unchecked {
128+
return _begin(array) + array.length * 0x20;
129+
}
130+
}
57131

58-
_quickSort(array, i, index); // Sort the left side of the pivot
59-
_quickSort(array, index + 1, j); // Sort the right side of the pivot
132+
/**
133+
* @dev Load memory word (as a bytes32) at location `ptr`.
134+
*/
135+
function _mload(uint256 ptr) private pure returns (bytes32 value) {
136+
assembly {
137+
value := mload(ptr)
60138
}
61139
}
62140

63141
/**
64-
* @dev Swaps the elements at positions `i` and `j` in the `arr` array.
142+
* @dev Swaps the elements memory location `ptr1` and `ptr2`.
65143
*/
66-
function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure {
144+
function _swap(uint256 ptr1, uint256 ptr2) private pure {
145+
assembly {
146+
let value1 := mload(ptr1)
147+
let value2 := mload(ptr2)
148+
mstore(ptr1, value2)
149+
mstore(ptr2, value1)
150+
}
151+
}
152+
153+
/// @dev Comparator for sorting arrays in increasing order.
154+
function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) {
155+
return a < b;
156+
}
157+
158+
/// @dev Helper: low level cast address memory array to uint256 memory array
159+
function _castToBytes32Array(address[] memory input) private pure returns (bytes32[] memory output) {
160+
assembly {
161+
output := input
162+
}
163+
}
164+
165+
/// @dev Helper: low level cast uint256 memory array to uint256 memory array
166+
function _castToBytes32Array(uint256[] memory input) private pure returns (bytes32[] memory output) {
167+
assembly {
168+
output := input
169+
}
170+
}
171+
172+
/// @dev Helper: low level cast address comp function to bytes32 comp function
173+
function _castToBytes32Comp(
174+
function(address, address) pure returns (bool) input
175+
) private pure returns (function(bytes32, bytes32) pure returns (bool) output) {
176+
assembly {
177+
output := input
178+
}
179+
}
180+
181+
/// @dev Helper: low level cast uint256 comp function to bytes32 comp function
182+
function _castToBytes32Comp(
183+
function(uint256, uint256) pure returns (bool) input
184+
) private pure returns (function(bytes32, bytes32) pure returns (bool) output) {
67185
assembly {
68-
let start := add(arr, 0x20) // Pointer to the first element of the array
69-
let pos_i := add(start, mul(i, 0x20))
70-
let pos_j := add(start, mul(j, 0x20))
71-
let val_i := mload(pos_i)
72-
let val_j := mload(pos_j)
73-
mstore(pos_i, val_j)
74-
mstore(pos_j, val_i)
186+
output := input
75187
}
76188
}
77189

scripts/helpers.js

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,15 @@ function range(start, stop = undefined, step = 1) {
77
stop = start;
88
start = 0;
99
}
10-
return start < stop
11-
? Array(Math.ceil((stop - start) / step))
12-
.fill()
13-
.map((_, i) => start + i * step)
14-
: [];
10+
return start < stop ? Array.from({ length: Math.ceil((stop - start) / step) }, (_, i) => start + i * step) : [];
1511
}
1612

1713
function unique(array, op = x => x) {
1814
return array.filter((obj, i) => array.findIndex(entry => op(obj) === op(entry)) === i);
1915
}
2016

2117
function zip(...args) {
22-
return Array(Math.max(...args.map(arg => arg.length)))
23-
.fill(null)
24-
.map((_, i) => args.map(arg => arg[i]));
18+
return Array.from({ length: Math.max(...args.map(arg => arg.length)) }, (_, i) => args.map(arg => arg[i]));
2519
}
2620

2721
function capitalize(str) {

test/finance/VestingWallet.test.js

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ async function fixture() {
5555
},
5656
};
5757

58-
const schedule = Array(64)
59-
.fill()
60-
.map((_, i) => (BigInt(i) * duration) / 60n + start);
58+
const schedule = Array.from({ length: 64 }, (_, i) => (BigInt(i) * duration) / 60n + start);
6159

6260
const vestingFn = timestamp => min(amount, (amount * (timestamp - start)) / duration);
6361

test/helpers/iterate.js

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ const mapValues = (obj, fn) => Object.fromEntries(Object.entries(obj).map(([k, v
55
const product = (...arrays) => arrays.reduce((a, b) => a.flatMap(ai => b.map(bi => [...ai, bi])), [[]]);
66
const unique = (...array) => array.filter((obj, i) => array.indexOf(obj) === i);
77
const zip = (...args) =>
8-
Array(Math.max(...args.map(array => array.length)))
9-
.fill()
10-
.map((_, i) => args.map(array => array[i]));
8+
Array.from({ length: Math.max(...args.map(array => array.length)) }, (_, i) => args.map(array => array[i]));
119

1210
module.exports = {
1311
mapValues,

test/helpers/random.js

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
const { ethers } = require('hardhat');
22

3-
const randomArray = (generator, arrayLength = 3) => Array(arrayLength).fill().map(generator);
4-
53
const generators = {
64
address: () => ethers.Wallet.createRandom().address,
75
bytes32: () => ethers.hexlify(ethers.randomBytes(32)),
@@ -15,6 +13,5 @@ generators.uint256.zero = 0n;
1513
generators.hexBytes.zero = '0x';
1614

1715
module.exports = {
18-
randomArray,
1916
generators,
2017
};

0 commit comments

Comments
 (0)