Skip to content

Commit f2c9c08

Browse files
committed
Use heap sort in BytesRefBucketedSort and IpBucketedSort too.
1 parent 12fa01f commit f2c9c08

File tree

3 files changed

+78
-58
lines changed

3 files changed

+78
-58
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BooleanBucketedSort.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ public BooleanBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize)
5353

5454
/**
5555
* Collects a {@code value} into a {@code bucket}.
56-
* <p>
57-
* It may or may not be inserted in the heap, depending on if it is better than the current root.
58-
* </p>
5956
*/
6057
public void collect(boolean value, int bucket) {
6158
long rootIndex = (long) bucket * 2;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BytesRefBucketedSort.java

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.elasticsearch.search.sort.BucketedSort;
2525
import org.elasticsearch.search.sort.SortOrder;
2626

27-
import java.util.Arrays;
2827
import java.util.stream.IntStream;
2928
import java.util.stream.LongStream;
3029

@@ -122,7 +121,7 @@ public void collect(BytesRef value, int bucket) {
122121
if (common.inHeapMode(bucket)) {
123122
if (betterThan(value, values.get(rootIndex).bytesRefView())) {
124123
clearedBytesAt(rootIndex).append(value);
125-
downHeap(rootIndex, 0);
124+
downHeap(rootIndex, 0, common.bucketSize);
126125
}
127126
checkInvariant(bucket);
128127
return;
@@ -138,7 +137,7 @@ public void collect(BytesRef value, int bucket) {
138137
clearedBytesAt(index).append(value);
139138
if (next == 0) {
140139
common.enableHeapMode(bucket);
141-
heapify(rootIndex);
140+
heapify(rootIndex, common.bucketSize);
142141
} else {
143142
ByteUtils.writeIntLE(next - 1, values.get(rootIndex).bytes(), 0);
144143
}
@@ -182,9 +181,6 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
182181
return blockFactory.newConstantNullBlock(selected.getPositionCount());
183182
}
184183

185-
// Used to sort the values in the bucket.
186-
BytesRef[] bucketValues = new BytesRef[common.bucketSize];
187-
188184
try (var builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) {
189185
for (int s = 0; s < selected.getPositionCount(); s++) {
190186
int bucket = selected.getInt(s);
@@ -212,25 +208,18 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
212208
continue;
213209
}
214210

215-
for (int i = 0; i < size; i++) {
216-
try (BreakingBytesRefBuilder bytes = values.get(start + i)) {
217-
bucketValues[i] = bytes.bytesRefView();
218-
}
219-
values.set(start + i, null);
211+
// If we are in the gathering mode, we need to heapify before sorting.
212+
if (common.inHeapMode(bucket) == false) {
213+
heapify(rootIndex, (int) size);
220214
}
221-
222-
// TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
223-
Arrays.sort(bucketValues, 0, (int) size);
215+
heapSort(rootIndex, (int) size);
224216

225217
builder.beginPositionEntry();
226-
if (common.order == SortOrder.ASC) {
227-
for (int i = 0; i < size; i++) {
228-
builder.appendBytesRef(bucketValues[i]);
229-
}
230-
} else {
231-
for (int i = (int) size - 1; i >= 0; i--) {
232-
builder.appendBytesRef(bucketValues[i]);
218+
for (int i = 0; i < size; i++) {
219+
try (BreakingBytesRefBuilder bytes = values.get(start + i)) {
220+
builder.appendBytesRef(bytes.bytesRefView());
233221
}
222+
values.set(start + i, null);
234223
}
235224
builder.endPositionEntry();
236225
}
@@ -339,10 +328,28 @@ private void fillGatherOffsets(long startingAt) {
339328
* </ul>
340329
* @param rootIndex the index the start of the bucket
341330
*/
342-
private void heapify(long rootIndex) {
343-
int maxParent = common.bucketSize / 2 - 1;
331+
private void heapify(long rootIndex, int heapSize) {
332+
int maxParent = heapSize / 2 - 1;
344333
for (int parent = maxParent; parent >= 0; parent--) {
345-
downHeap(rootIndex, parent);
334+
downHeap(rootIndex, parent, heapSize);
335+
}
336+
}
337+
338+
/**
339+
* Sorts all the values in the heap using heap sort algorithm.
340+
* This runs in {@code O(n log n)} time.
341+
* @param rootIndex index of the start of the bucket
342+
* @param heapSize Number of values that belong to the heap.
343+
* Can be less than bucketSize.
344+
* In such a case, the remaining values in range
345+
* (rootIndex + heapSize, rootIndex + bucketSize)
346+
* are *not* considered part of the heap.
347+
*/
348+
private void heapSort(long rootIndex, int heapSize) {
349+
while (heapSize > 0) {
350+
swap(rootIndex, rootIndex + heapSize - 1);
351+
heapSize--;
352+
downHeap(rootIndex, 0, heapSize);
346353
}
347354
}
348355

@@ -352,24 +359,27 @@ private void heapify(long rootIndex) {
352359
* @param rootIndex index of the start of the bucket
353360
* @param parent Index within the bucket of the parent to check.
354361
* For example, 0 is the "root".
362+
* @param heapSize Number of values that belong to the heap.
363+
* Can be less than bucketSize.
364+
* In such a case, the remaining values in range
365+
* (rootIndex + heapSize, rootIndex + bucketSize)
366+
* are *not* considered part of the heap.
355367
*/
356-
private void downHeap(long rootIndex, int parent) {
368+
private void downHeap(long rootIndex, int parent, int heapSize) {
357369
while (true) {
358370
long parentIndex = rootIndex + parent;
359371
int worst = parent;
360372
long worstIndex = parentIndex;
361373
int leftChild = parent * 2 + 1;
362374
long leftIndex = rootIndex + leftChild;
363-
if (leftChild < common.bucketSize) {
375+
if (leftChild < heapSize) {
364376
if (betterThan(values.get(worstIndex).bytesRefView(), values.get(leftIndex).bytesRefView())) {
365377
worst = leftChild;
366378
worstIndex = leftIndex;
367379
}
368380
int rightChild = leftChild + 1;
369381
long rightIndex = rootIndex + rightChild;
370-
if (rightChild < common.bucketSize
371-
&& betterThan(values.get(worstIndex).bytesRefView(), values.get(rightIndex).bytesRefView())) {
372-
382+
if (rightChild < heapSize && betterThan(values.get(worstIndex).bytesRefView(), values.get(rightIndex).bytesRefView())) {
373383
worst = rightChild;
374384
worstIndex = rightIndex;
375385
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/IpBucketedSort.java

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.elasticsearch.search.sort.BucketedSort;
2222
import org.elasticsearch.search.sort.SortOrder;
2323

24-
import java.util.Arrays;
2524
import java.util.stream.IntStream;
2625

2726
/**
@@ -100,7 +99,7 @@ public void collect(BytesRef value, int bucket) {
10099
if (common.inHeapMode(bucket)) {
101100
if (betterThan(value, get(rootIndex, scratch1))) {
102101
set(rootIndex, value);
103-
downHeap(rootIndex, 0);
102+
downHeap(rootIndex, 0, common.bucketSize);
104103
}
105104
return;
106105
}
@@ -115,7 +114,7 @@ public void collect(BytesRef value, int bucket) {
115114
set(index, value);
116115
if (next == 0) {
117116
common.enableHeapMode(bucket);
118-
heapify(rootIndex);
117+
heapify(rootIndex, common.bucketSize);
119118
} else {
120119
setNextGatherOffset(rootIndex, next - 1);
121120
}
@@ -163,14 +162,12 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
163162
return blockFactory.newConstantNullBlock(selected.getPositionCount());
164163
}
165164

166-
// Used to sort the values in the bucket.
167-
var bucketValues = new BytesRef[common.bucketSize];
168-
169165
try (var builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) {
170166
for (int s = 0; s < selected.getPositionCount(); s++) {
171167
int bucket = selected.getInt(s);
172168

173169
var bounds = getBucketValuesIndexes(bucket);
170+
var rootIndex = bounds.v1();
174171
var size = bounds.v2() - bounds.v1();
175172

176173
if (size == 0) {
@@ -179,26 +176,19 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) {
179176
}
180177

181178
if (size == 1) {
182-
builder.appendBytesRef(get(bounds.v1(), scratch1));
179+
builder.appendBytesRef(get(rootIndex, scratch1));
183180
continue;
184181
}
185182

186-
for (int i = 0; i < size; i++) {
187-
bucketValues[i] = get(bounds.v1() + i, new BytesRef());
183+
// If we are in the gathering mode, we need to heapify before sorting.
184+
if (common.inHeapMode(bucket) == false) {
185+
heapify(rootIndex, (int) size);
188186
}
189-
190-
// TODO: Make use of heap structures to faster iterate in order instead of copying and sorting
191-
Arrays.sort(bucketValues, 0, (int) size);
187+
heapSort(rootIndex, (int) size);
192188

193189
builder.beginPositionEntry();
194-
if (common.order == SortOrder.ASC) {
195-
for (int i = 0; i < size; i++) {
196-
builder.appendBytesRef(bucketValues[i]);
197-
}
198-
} else {
199-
for (int i = (int) size - 1; i >= 0; i--) {
200-
builder.appendBytesRef(bucketValues[i]);
201-
}
190+
for (int i = 0; i < size; i++) {
191+
builder.appendBytesRef(get(rootIndex + i, new BytesRef()));
202192
}
203193
builder.endPositionEntry();
204194
}
@@ -319,10 +309,28 @@ private void fillGatherOffsets(long startingAt) {
319309
* </ul>
320310
* @param rootIndex the index the start of the bucket
321311
*/
322-
private void heapify(long rootIndex) {
323-
int maxParent = common.bucketSize / 2 - 1;
312+
private void heapify(long rootIndex, int heapSize) {
313+
int maxParent = heapSize / 2 - 1;
324314
for (int parent = maxParent; parent >= 0; parent--) {
325-
downHeap(rootIndex, parent);
315+
downHeap(rootIndex, parent, heapSize);
316+
}
317+
}
318+
319+
/**
320+
* Sorts all the values in the heap using heap sort algorithm.
321+
* This runs in {@code O(n log n)} time.
322+
* @param rootIndex index of the start of the bucket
323+
* @param heapSize Number of values that belong to the heap.
324+
* Can be less than bucketSize.
325+
* In such a case, the remaining values in range
326+
* (rootIndex + heapSize, rootIndex + bucketSize)
327+
* are *not* considered part of the heap.
328+
*/
329+
private void heapSort(long rootIndex, int heapSize) {
330+
while (heapSize > 0) {
331+
swap(rootIndex, rootIndex + heapSize - 1);
332+
heapSize--;
333+
downHeap(rootIndex, 0, heapSize);
326334
}
327335
}
328336

@@ -332,22 +340,27 @@ private void heapify(long rootIndex) {
332340
* @param rootIndex index of the start of the bucket
333341
* @param parent Index within the bucket of the parent to check.
334342
* For example, 0 is the "root".
343+
* @param heapSize Number of values that belong to the heap.
344+
* Can be less than bucketSize.
345+
* In such a case, the remaining values in range
346+
* (rootIndex + heapSize, rootIndex + bucketSize)
347+
* are *not* considered part of the heap.
335348
*/
336-
private void downHeap(long rootIndex, int parent) {
349+
private void downHeap(long rootIndex, int parent, int heapSize) {
337350
while (true) {
338351
long parentIndex = rootIndex + parent;
339352
int worst = parent;
340353
long worstIndex = parentIndex;
341354
int leftChild = parent * 2 + 1;
342355
long leftIndex = rootIndex + leftChild;
343-
if (leftChild < common.bucketSize) {
356+
if (leftChild < heapSize) {
344357
if (betterThan(get(worstIndex, scratch1), get(leftIndex, scratch2))) {
345358
worst = leftChild;
346359
worstIndex = leftIndex;
347360
}
348361
int rightChild = leftChild + 1;
349362
long rightIndex = rootIndex + rightChild;
350-
if (rightChild < common.bucketSize && betterThan(get(worstIndex, scratch1), get(rightIndex, scratch2))) {
363+
if (rightChild < heapSize && betterThan(get(worstIndex, scratch1), get(rightIndex, scratch2))) {
351364
worst = rightChild;
352365
worstIndex = rightIndex;
353366
}

0 commit comments

Comments
 (0)