2121import org .elasticsearch .search .sort .BucketedSort ;
2222import org .elasticsearch .search .sort .SortOrder ;
2323
24- import java .util .Arrays ;
2524import java .util .stream .IntStream ;
2625
2726/**
@@ -100,7 +99,7 @@ public void collect(BytesRef value, int bucket) {
10099 if (common .inHeapMode (bucket )) {
101100 if (betterThan (value , get (rootIndex , scratch1 ))) {
102101 set (rootIndex , value );
103- downHeap (rootIndex , 0 );
102+ downHeap (rootIndex , 0 , common . bucketSize );
104103 }
105104 return ;
106105 }
@@ -115,7 +114,7 @@ public void collect(BytesRef value, int bucket) {
115114 set (index , value );
116115 if (next == 0 ) {
117116 common .enableHeapMode (bucket );
118- heapify (rootIndex );
117+ heapify (rootIndex , common . bucketSize );
119118 } else {
120119 setNextGatherOffset (rootIndex , next - 1 );
121120 }
@@ -163,14 +162,12 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
163162 return blockFactory .newConstantNullBlock (selected .getPositionCount ());
164163 }
165164
166- // Used to sort the values in the bucket.
167- var bucketValues = new BytesRef [common .bucketSize ];
168-
169165 try (var builder = blockFactory .newBytesRefBlockBuilder (selected .getPositionCount ())) {
170166 for (int s = 0 ; s < selected .getPositionCount (); s ++) {
171167 int bucket = selected .getInt (s );
172168
173169 var bounds = getBucketValuesIndexes (bucket );
170+ var rootIndex = bounds .v1 ();
174171 var size = bounds .v2 () - bounds .v1 ();
175172
176173 if (size == 0 ) {
@@ -179,26 +176,19 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
179176 }
180177
181178 if (size == 1 ) {
182- builder .appendBytesRef (get (bounds . v1 () , scratch1 ));
179+ builder .appendBytesRef (get (rootIndex , scratch1 ));
183180 continue ;
184181 }
185182
186- for (int i = 0 ; i < size ; i ++) {
187- bucketValues [i ] = get (bounds .v1 () + i , new BytesRef ());
183+ // If we are in the gathering mode, we need to heapify before sorting.
184+ if (common .inHeapMode (bucket ) == false ) {
185+ heapify (rootIndex , (int ) size );
188186 }
189-
190- // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
191- Arrays .sort (bucketValues , 0 , (int ) size );
187+ heapSort (rootIndex , (int ) size );
192188
193189 builder .beginPositionEntry ();
194- if (common .order == SortOrder .ASC ) {
195- for (int i = 0 ; i < size ; i ++) {
196- builder .appendBytesRef (bucketValues [i ]);
197- }
198- } else {
199- for (int i = (int ) size - 1 ; i >= 0 ; i --) {
200- builder .appendBytesRef (bucketValues [i ]);
201- }
190+ for (int i = 0 ; i < size ; i ++) {
191+ builder .appendBytesRef (get (rootIndex + i , new BytesRef ()));
202192 }
203193 builder .endPositionEntry ();
204194 }
@@ -319,10 +309,28 @@ private void fillGatherOffsets(long startingAt) {
319309 * </ul>
320310 * @param rootIndex the index the start of the bucket
321311 */
322- private void heapify (long rootIndex ) {
323- int maxParent = common . bucketSize / 2 - 1 ;
312+ private void heapify (long rootIndex , int heapSize ) {
313+ int maxParent = heapSize / 2 - 1 ;
324314 for (int parent = maxParent ; parent >= 0 ; parent --) {
325- downHeap (rootIndex , parent );
315+ downHeap (rootIndex , parent , heapSize );
316+ }
317+ }
318+
319+ /**
320+ * Sorts all the values in the heap using heap sort algorithm.
321+ * This runs in {@code O(n log n)} time.
322+ * @param rootIndex index of the start of the bucket
323+ * @param heapSize Number of values that belong to the heap.
324+ * Can be less than bucketSize.
325+ * In such a case, the remaining values in range
326+ * (rootIndex + heapSize, rootIndex + bucketSize)
327+ * are *not* considered part of the heap.
328+ */
329+ private void heapSort (long rootIndex , int heapSize ) {
330+ while (heapSize > 0 ) {
331+ swap (rootIndex , rootIndex + heapSize - 1 );
332+ heapSize --;
333+ downHeap (rootIndex , 0 , heapSize );
326334 }
327335 }
328336
@@ -332,22 +340,27 @@ private void heapify(long rootIndex) {
332340 * @param rootIndex index of the start of the bucket
333341 * @param parent Index within the bucket of the parent to check.
334342 * For example, 0 is the "root".
343+ * @param heapSize Number of values that belong to the heap.
344+ * Can be less than bucketSize.
345+ * In such a case, the remaining values in range
346+ * (rootIndex + heapSize, rootIndex + bucketSize)
347+ * are *not* considered part of the heap.
335348 */
336- private void downHeap (long rootIndex , int parent ) {
349+ private void downHeap (long rootIndex , int parent , int heapSize ) {
337350 while (true ) {
338351 long parentIndex = rootIndex + parent ;
339352 int worst = parent ;
340353 long worstIndex = parentIndex ;
341354 int leftChild = parent * 2 + 1 ;
342355 long leftIndex = rootIndex + leftChild ;
343- if (leftChild < common . bucketSize ) {
356+ if (leftChild < heapSize ) {
344357 if (betterThan (get (worstIndex , scratch1 ), get (leftIndex , scratch2 ))) {
345358 worst = leftChild ;
346359 worstIndex = leftIndex ;
347360 }
348361 int rightChild = leftChild + 1 ;
349362 long rightIndex = rootIndex + rightChild ;
350- if (rightChild < common . bucketSize && betterThan (get (worstIndex , scratch1 ), get (rightIndex , scratch2 ))) {
363+ if (rightChild < heapSize && betterThan (get (worstIndex , scratch1 ), get (rightIndex , scratch2 ))) {
351364 worst = rightChild ;
352365 worstIndex = rightIndex ;
353366 }
0 commit comments