@@ -13,7 +13,7 @@ library Arrays {
13
13
using StorageSlot for bytes32 ;
14
14
15
15
/**
16
- * @dev Sort an array (in memory) in increasing order .
16
+ * @dev Sort an array of bytes32 (in memory) following the provided comparator function .
17
17
*
18
18
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for
19
19
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
@@ -23,55 +23,167 @@ library Arrays {
23
23
* when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
24
24
* consume more gas than is available in a block, leading to potential DoS.
25
25
*/
26
+ function sort (
27
+ bytes32 [] memory array ,
28
+ function (bytes32 , bytes32 ) pure returns (bool ) comp
29
+ ) internal pure returns (bytes32 [] memory ) {
30
+ _quickSort (_begin (array), _end (array), comp);
31
+ return array;
32
+ }
33
+
34
+ /**
35
+ * @dev Variant of {sort} that sorts an array of bytes32 in increasing order.
36
+ */
37
+ function sort (bytes32 [] memory array ) internal pure returns (bytes32 [] memory ) {
38
+ return sort (array, _defaultComp);
39
+ }
40
+
41
+ /**
42
+ * @dev Variant of {sort} that sorts an array of address following a provided comparator function.
43
+ */
44
+ function sort (
45
+ address [] memory array ,
46
+ function (address , address ) pure returns (bool ) comp
47
+ ) internal pure returns (address [] memory ) {
48
+ sort (_castToBytes32Array (array), _castToBytes32Comp (comp));
49
+ return array;
50
+ }
51
+
52
+ /**
53
+ * @dev Variant of {sort} that sorts an array of address in increasing order.
54
+ */
55
+ function sort (address [] memory array ) internal pure returns (address [] memory ) {
56
+ sort (_castToBytes32Array (array), _defaultComp);
57
+ return array;
58
+ }
59
+
60
+ /**
61
+ * @dev Variant of {sort} that sorts an array of uint256 following a provided comparator function.
62
+ */
63
+ function sort (
64
+ uint256 [] memory array ,
65
+ function (uint256 , uint256 ) pure returns (bool ) comp
66
+ ) internal pure returns (uint256 [] memory ) {
67
+ sort (_castToBytes32Array (array), _castToBytes32Comp (comp));
68
+ return array;
69
+ }
70
+
71
+ /**
72
+ * @dev Variant of {sort} that sorts an array of uint256 in increasing order.
73
+ */
26
74
function sort (uint256 [] memory array ) internal pure returns (uint256 [] memory ) {
27
- _quickSort ( array, 0 , array. length );
75
+ sort ( _castToBytes32Array ( array), _defaultComp );
28
76
return array;
29
77
}
30
78
31
79
/**
32
- * @dev Performs a quick sort on an array in memory. The array is sorted in increasing order.
80
+ * @dev Performs a quick sort of a segment of memory. The segment sorted starts at `begin` (inclusive), and stops
81
+ * at end (exclusive). Sorting follows the `comp` comparator.
33
82
*
34
- * Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in
35
- * subcalls.
83
+ * Invariant: `begin <= end`. This is the case when initially called by {sort} and is preserved in subcalls.
84
+ *
85
+ * IMPORTANT: Memory locations between `begin` and `end` are not validated/zeroed. This function should
86
+ * be used only if the limits are within a memory array.
36
87
*/
37
- function _quickSort (uint256 [] memory array , uint256 i , uint256 j ) private pure {
88
+ function _quickSort (uint256 begin , uint256 end , function ( bytes32 , bytes32 ) pure returns ( bool ) comp ) private pure {
38
89
unchecked {
39
- // Can't overflow given `i <= j`
40
- if (j - i < 2 ) return ;
90
+ if (end - begin < 0x40 ) return ;
41
91
42
92
// Use first element as pivot
43
- uint256 pivot = unsafeMemoryAccess (array, i );
93
+ bytes32 pivot = _mload (begin );
44
94
// Position where the pivot should be at the end of the loop
45
- uint256 index = i;
46
-
47
- for (uint256 k = i + 1 ; k < j; ++ k) {
48
- // Unsafe access is safe given `k < j <= array.length`.
49
- if (unsafeMemoryAccess (array, k) < pivot) {
50
- // If array[k] is smaller than the pivot, we increment the index and move array[k] there.
51
- _swap (array, ++ index, k);
95
+ uint256 pos = begin;
96
+
97
+ for (uint256 it = begin + 0x20 ; it < end; it += 0x20 ) {
98
+ if (comp (_mload (it), pivot)) {
99
+ // If the value stored at the iterator's position comes before the pivot, we increment the
100
+ // position of the pivot and move the value there.
101
+ pos += 0x20 ;
102
+ _swap (pos, it);
52
103
}
53
104
}
54
105
55
- // Swap pivot into place
56
- _swap (array, i, index);
106
+ _swap (begin, pos); // Swap pivot into place
107
+ _quickSort (begin, pos, comp); // Sort the left side of the pivot
108
+ _quickSort (pos + 0x20 , end, comp); // Sort the right side of the pivot
109
+ }
110
+ }
111
+
112
+ /**
113
+ * @dev Pointer to the memory location of the first element of `array`.
114
+ */
115
+ function _begin (bytes32 [] memory array ) private pure returns (uint256 ptr ) {
116
+ /// @solidity memory-safe-assembly
117
+ assembly {
118
+ ptr := add (array, 0x20 )
119
+ }
120
+ }
121
+
122
+ /**
123
+ * @dev Pointer to the memory location of the first memory word (32bytes) after `array`. This is the memory word
124
+ * that comes just after the last element of the array.
125
+ */
126
+ function _end (bytes32 [] memory array ) private pure returns (uint256 ptr ) {
127
+ unchecked {
128
+ return _begin (array) + array.length * 0x20 ;
129
+ }
130
+ }
57
131
58
- _quickSort (array, i, index); // Sort the left side of the pivot
59
- _quickSort (array, index + 1 , j); // Sort the right side of the pivot
132
+ /**
133
+ * @dev Load memory word (as a bytes32) at location `ptr`.
134
+ */
135
+ function _mload (uint256 ptr ) private pure returns (bytes32 value ) {
136
+ assembly {
137
+ value := mload (ptr)
60
138
}
61
139
}
62
140
63
141
/**
64
- * @dev Swaps the elements at positions `i ` and `j` in the `arr` array .
142
+ * @dev Swaps the elements memory location `ptr1 ` and `ptr2` .
65
143
*/
66
- function _swap (uint256 [] memory arr , uint256 i , uint256 j ) private pure {
144
+ function _swap (uint256 ptr1 , uint256 ptr2 ) private pure {
145
+ assembly {
146
+ let value1 := mload (ptr1)
147
+ let value2 := mload (ptr2)
148
+ mstore (ptr1, value2)
149
+ mstore (ptr2, value1)
150
+ }
151
+ }
152
+
153
+ /// @dev Comparator for sorting arrays in increasing order.
154
+ function _defaultComp (bytes32 a , bytes32 b ) private pure returns (bool ) {
155
+ return a < b;
156
+ }
157
+
158
+ /// @dev Helper: low level cast address memory array to uint256 memory array
159
+ function _castToBytes32Array (address [] memory input ) private pure returns (bytes32 [] memory output ) {
160
+ assembly {
161
+ output := input
162
+ }
163
+ }
164
+
165
+ /// @dev Helper: low level cast uint256 memory array to uint256 memory array
166
+ function _castToBytes32Array (uint256 [] memory input ) private pure returns (bytes32 [] memory output ) {
167
+ assembly {
168
+ output := input
169
+ }
170
+ }
171
+
172
+ /// @dev Helper: low level cast address comp function to bytes32 comp function
173
+ function _castToBytes32Comp (
174
+ function (address , address ) pure returns (bool ) input
175
+ ) private pure returns (function (bytes32 , bytes32 ) pure returns (bool ) output) {
176
+ assembly {
177
+ output := input
178
+ }
179
+ }
180
+
181
+ /// @dev Helper: low level cast uint256 comp function to bytes32 comp function
182
+ function _castToBytes32Comp (
183
+ function (uint256 , uint256 ) pure returns (bool ) input
184
+ ) private pure returns (function (bytes32 , bytes32 ) pure returns (bool ) output) {
67
185
assembly {
68
- let start := add (arr, 0x20 ) // Pointer to the first element of the array
69
- let pos_i := add (start, mul (i, 0x20 ))
70
- let pos_j := add (start, mul (j, 0x20 ))
71
- let val_i := mload (pos_i)
72
- let val_j := mload (pos_j)
73
- mstore (pos_i, val_j)
74
- mstore (pos_j, val_i)
186
+ output := input
75
187
}
76
188
}
77
189
0 commit comments