Skip to content

Commit 38441b9

Browse files
committed
Simplify RandomSampleOperator
1 parent 0063737 commit 38441b9

File tree

1 file changed

+43
-240
lines changed

1 file changed

+43
-240
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RandomSampleOperator.java

Lines changed: 43 additions & 240 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88
package org.elasticsearch.compute.operator;
99

10-
import com.carrotsearch.hppc.BitMixer;
11-
12-
import org.apache.lucene.search.DocIdSetIterator;
1310
import org.elasticsearch.TransportVersion;
1411
import org.elasticsearch.TransportVersions;
1512
import org.elasticsearch.common.Strings;
@@ -22,43 +19,32 @@
2219
import org.elasticsearch.xcontent.XContentBuilder;
2320

2421
import java.io.IOException;
25-
import java.util.ArrayDeque;
26-
import java.util.ArrayList;
2722
import java.util.Arrays;
28-
import java.util.List;
23+
import java.util.Deque;
24+
import java.util.LinkedList;
2925
import java.util.Objects;
3026
import java.util.SplittableRandom;
3127

3228
public class RandomSampleOperator implements Operator {
3329

34-
// The threshold for the number of rows to collect in a batch before starting sampling it.
35-
private static final int ROWS_BATCH_THRESHOLD = 10_000;
36-
// How many batches can be to keep in memory and still accept new input Pages.
37-
// Besides these many buffered batches, the operator holds an additional batch that's being sampled.
38-
private static final int MAX_BUFFERED_BATCHES = 1;
39-
40-
private final double probability;
41-
private final int seed;
42-
43-
private boolean collecting = true;
44-
private boolean isFinished = false;
45-
private final PageBatching pageBatching;
46-
private BatchSampling currentSampling;
30+
private boolean finished;
31+
private final Deque<Page> outputPages;
32+
private final RandomSamplingQuery.RandomSamplingIterator randomSamplingIterator;
4733

4834
private int pagesCollected = 0;
4935
private int pagesEmitted = 0;
5036
private int rowsCollected = 0;
5137
private int rowsEmitted = 0;
52-
private int batchesSampled = 0;
5338

5439
private long collectNanos;
5540
private long emitNanos;
5641

5742
public RandomSampleOperator(double probability, int seed) {
58-
this.probability = probability;
59-
this.seed = seed;
60-
// TODO derive the threshold from the probability and a max cap
61-
pageBatching = new PageBatching(ROWS_BATCH_THRESHOLD, MAX_BUFFERED_BATCHES);
43+
finished = false;
44+
outputPages = new LinkedList<>();
45+
SplittableRandom random = new SplittableRandom(seed);
46+
randomSamplingIterator = new RandomSamplingQuery.RandomSamplingIterator(Integer.MAX_VALUE, probability, random::nextInt);
47+
randomSamplingIterator.nextDoc();
6248
}
6349

6450
public record Factory(double probability, int seed) implements OperatorFactory {
@@ -79,7 +65,7 @@ public String describe() {
7965
*/
8066
@Override
8167
public boolean needsInput() {
82-
return collecting && pageBatching.capacityAvailable();
68+
return finished == false;
8369
}
8470

8571
/**
@@ -91,80 +77,64 @@ public boolean needsInput() {
9177
@Override
9278
public void addInput(Page page) {
9379
final var addStart = System.nanoTime();
94-
collect(page);
80+
createOutputPage(page);
81+
rowsCollected += page.getPositionCount();
82+
pagesCollected++;
83+
page.releaseBlocks();
9584
collectNanos += System.nanoTime() - addStart;
9685
}
9786

98-
private void collect(Page page) {
99-
pagesCollected++;
100-
rowsCollected += page.getPositionCount();
101-
pageBatching.addPage(page);
87+
private void createOutputPage(Page page) {
88+
final int[] sampledPositions = new int[page.getPositionCount()];
89+
int sampledIdx = 0;
90+
for (int i = randomSamplingIterator.docID(); i - rowsCollected < page.getPositionCount(); i = randomSamplingIterator.nextDoc()) {
91+
sampledPositions[sampledIdx++] = i - rowsCollected;
92+
}
93+
if (sampledIdx > 0) {
94+
outputPages.add(page.filter(Arrays.copyOf(sampledPositions, sampledIdx)));
95+
}
10296
}
10397

10498
/**
10599
* notifies the operator that it won't receive any more input pages
106100
*/
107101
@Override
108102
public void finish() {
109-
if (collecting && rowsCollected > 0) { // finish() can be called multiple times
110-
pageBatching.flush();
111-
}
112-
collecting = false;
103+
finished = true;
113104
}
114105

115106
/**
116107
* whether the operator has finished processing all input pages and made the corresponding output pages available
117108
*/
118109
@Override
119110
public boolean isFinished() {
120-
return isFinished;
111+
return finished && outputPages.isEmpty();
121112
}
122113

123-
/**
124-
* returns non-null if output page available. Only called when isFinished() == false
125-
*
126-
* @throws UnsupportedOperationException if the operator is a {@link SinkOperator}
127-
*/
128114
@Override
129115
public Page getOutput() {
130116
final var emitStart = System.nanoTime();
131-
Page page = emit();
117+
Page page;
118+
if (outputPages.isEmpty()) {
119+
page = null;
120+
} else {
121+
page = outputPages.removeFirst();
122+
pagesEmitted++;
123+
rowsEmitted += page.getPositionCount();
124+
}
132125
emitNanos += System.nanoTime() - emitStart;
133126
return page;
134127
}
135128

136-
private Page emit() {
137-
if (currentSampling == null) {
138-
if (pageBatching.hasNext() == false) {
139-
if (collecting == false) {
140-
isFinished = true;
141-
}
142-
return null; // not enough pages on the input yet
143-
}
144-
final var currentBatch = pageBatching.next();
145-
currentSampling = new BatchSampling(currentBatch, probability, seed);
146-
batchesSampled++;
147-
}
148-
149-
final var page = currentSampling.next();
150-
if (page != null) {
151-
rowsEmitted += page.getPositionCount();
152-
pagesEmitted++;
153-
return page;
154-
} // else: current batch is exhausted
155-
156-
currentSampling.close();
157-
currentSampling = null;
158-
return emit();
159-
}
160-
161129
/**
162130
* notifies the operator that it won't be used anymore (i.e. none of the other methods called),
163131
* and its resources can be cleaned up
164132
*/
165133
@Override
166134
public void close() {
167-
pageBatching.close();
135+
for (Page page : outputPages) {
136+
page.releaseBlocks();
137+
}
168138
}
169139

170140
@Override
@@ -174,175 +144,12 @@ public String toString() {
174144

175145
@Override
176146
public Operator.Status status() {
177-
return new Status(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted, batchesSampled);
178-
}
179-
180-
private static class SamplingIterator {
181-
182-
private final RandomSamplingQuery.RandomSamplingIterator samplingIterator;
183-
private int nextDoc = -1;
184-
185-
SamplingIterator(int maxDoc, double probability, int seed) {
186-
final SplittableRandom random = new SplittableRandom(BitMixer.mix(seed));
187-
samplingIterator = new RandomSamplingQuery.RandomSamplingIterator(maxDoc, probability, random::nextInt);
188-
advance();
189-
}
190-
191-
boolean hasNext() {
192-
return nextDoc != DocIdSetIterator.NO_MORE_DOCS;
193-
}
194-
195-
int next() {
196-
return nextDoc;
197-
}
198-
199-
void advance() {
200-
assert hasNext() : "No more docs to sample";
201-
nextDoc = samplingIterator.nextDoc();
202-
}
203-
}
204-
205-
private record PagesBatch(ArrayDeque<Page> batch, int rowCount) {}
206-
207-
private static class PageBatching {
208-
209-
private final int collectingRowThreshold;
210-
private final int maxBufferedBatches;
211-
212-
private final List<PagesBatch> batches = new ArrayList<>();
213-
214-
private int collectingBatchRowCount = 0;
215-
private ArrayDeque<Page> collectingBatch = new ArrayDeque<>();
216-
217-
PageBatching(int collectingRowThreshold, int maxBufferedBatches) {
218-
this.collectingRowThreshold = collectingRowThreshold;
219-
this.maxBufferedBatches = maxBufferedBatches;
220-
}
221-
222-
void addPage(Page page) {
223-
collectingBatch.add(page);
224-
collectingBatchRowCount += page.getPositionCount();
225-
if (collectingBatchRowCount >= collectingRowThreshold) {
226-
rotate();
227-
}
228-
}
229-
230-
private void rotate() {
231-
batches.add(new PagesBatch(collectingBatch, collectingBatchRowCount));
232-
collectingBatch = new ArrayDeque<>();
233-
collectingBatchRowCount = 0;
234-
}
235-
236-
boolean hasNext() {
237-
return batches.isEmpty() == false;
238-
}
239-
240-
PagesBatch next() {
241-
return batches.removeFirst();
242-
}
243-
244-
public boolean capacityAvailable() {
245-
return batches.size() < maxBufferedBatches;
246-
}
247-
248-
void flush() {
249-
while (batches.isEmpty() == false) {
250-
var batch = batches.removeFirst();
251-
collectingBatch.addAll(batch.batch);
252-
collectingBatchRowCount += batch.rowCount;
253-
}
254-
if (collectingBatch.isEmpty() == false) {
255-
rotate();
256-
}
257-
}
258-
259-
void close() {
260-
assert batches.isEmpty() : "There are still available batches";
261-
assert collectingBatch.isEmpty() : "Current batch has not been rotated";
262-
}
147+
return new Status(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted);
263148
}
264149

265-
private static class BatchSampling {
266-
267-
private final ArrayDeque<Page> pagesDeque;
268-
private final SamplingIterator samplingIterator;
269-
270-
private int rowsProcessed = 0;
271-
272-
BatchSampling(PagesBatch batch, double probability, int seed) {
273-
pagesDeque = batch.batch;
274-
samplingIterator = new SamplingIterator(batch.rowCount, probability, seed);
275-
}
276-
277-
Page next() {
278-
while (pagesDeque.isEmpty() == false) {
279-
final var page = pagesDeque.poll();
280-
final int positionCount = page.getPositionCount();
281-
final int[] sampledPositions = new int[positionCount];
282-
int sampledIdx = 0;
283-
284-
while (true) {
285-
if (samplingIterator.hasNext()) {
286-
var docOffset = samplingIterator.next() - rowsProcessed;
287-
if (docOffset < positionCount) {
288-
sampledPositions[sampledIdx++] = docOffset;
289-
samplingIterator.advance();
290-
} else {
291-
// position falls outside the current page
292-
break;
293-
}
294-
} else {
295-
// no more docs to sample
296-
drainPages();
297-
break;
298-
}
299-
}
300-
rowsProcessed += positionCount;
301-
302-
if (sampledIdx > 0) {
303-
var filter = Arrays.copyOf(sampledPositions, sampledIdx);
304-
return page.filter(filter);
305-
} // else: fetch a new page (if any left)
306-
307-
releasePage(page);
308-
}
309-
310-
return null;
311-
}
312-
313-
private void drainPages() {
314-
Page page;
315-
do {
316-
page = pagesDeque.poll();
317-
} while (releasePage(page));
318-
}
319-
320-
/**
321-
* Returns true if there was a non-null page that was released.
322-
*/
323-
private static boolean releasePage(Page page) {
324-
if (page != null) {
325-
page.releaseBlocks();
326-
return true;
327-
}
328-
return false;
329-
}
330-
331-
void close() {
332-
assert pagesDeque.isEmpty() : "There are still unreleased pages";
333-
assert samplingIterator.hasNext() == false : "There are still docs to sample";
334-
}
335-
}
336-
337-
private record Status(
338-
long collectNanos,
339-
long emitNanos,
340-
int pagesCollected,
341-
int pagesEmitted,
342-
int rowsCollected,
343-
int rowsEmitted,
344-
int batchesSampled
345-
) implements Operator.Status {
150+
private record Status(long collectNanos, long emitNanos, int pagesCollected, int pagesEmitted, int rowsCollected, int rowsEmitted)
151+
implements
152+
Operator.Status {
346153

347154
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
348155
Operator.Status.class,
@@ -357,7 +164,6 @@ private record Status(
357164
streamInput.readVInt(),
358165
streamInput.readVInt(),
359166
streamInput.readVInt(),
360-
streamInput.readVInt(),
361167
streamInput.readVInt()
362168
);
363169
}
@@ -370,7 +176,6 @@ public void writeTo(StreamOutput out) throws IOException {
370176
out.writeVInt(pagesEmitted);
371177
out.writeVInt(rowsCollected);
372178
out.writeVInt(rowsEmitted);
373-
out.writeVInt(batchesSampled);
374179
}
375180

376181
@Override
@@ -393,7 +198,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
393198
builder.field("pages_emitted", pagesEmitted);
394199
builder.field("rows_collected", rowsCollected);
395200
builder.field("rows_emitted", rowsEmitted);
396-
builder.field("batches_sampled", batchesSampled);
397201
return builder.endObject();
398202
}
399203

@@ -407,13 +211,12 @@ public boolean equals(Object o) {
407211
&& pagesCollected == other.pagesCollected
408212
&& pagesEmitted == other.pagesEmitted
409213
&& rowsCollected == other.rowsCollected
410-
&& rowsEmitted == other.rowsEmitted
411-
&& batchesSampled == other.batchesSampled;
214+
&& rowsEmitted == other.rowsEmitted;
412215
}
413216

414217
@Override
415218
public int hashCode() {
416-
return Objects.hash(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted, batchesSampled);
219+
return Objects.hash(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted);
417220
}
418221

419222
@Override

0 commit comments

Comments (0)