 
 package org.elasticsearch.compute.operator;
 
-import com.carrotsearch.hppc.BitMixer;
-
-import org.apache.lucene.search.DocIdSetIterator;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.ArrayDeque;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.List;
+import java.util.Deque;
+import java.util.LinkedList;
 import java.util.Objects;
 import java.util.SplittableRandom;
 
 public class RandomSampleOperator implements Operator {
 
-    // The threshold for the number of rows to collect in a batch before starting sampling it.
-    private static final int ROWS_BATCH_THRESHOLD = 10_000;
-    // How many batches can be to keep in memory and still accept new input Pages.
-    // Besides these many buffered batches, the operator holds an additional batch that's being sampled.
-    private static final int MAX_BUFFERED_BATCHES = 1;
-
-    private final double probability;
-    private final int seed;
-
-    private boolean collecting = true;
-    private boolean isFinished = false;
-    private final PageBatching pageBatching;
-    private BatchSampling currentSampling;
+    private boolean finished;
+    private final Deque<Page> outputPages;
+    private final RandomSamplingQuery.RandomSamplingIterator randomSamplingIterator;
 
     private int pagesCollected = 0;
     private int pagesEmitted = 0;
     private int rowsCollected = 0;
     private int rowsEmitted = 0;
-    private int batchesSampled = 0;
 
     private long collectNanos;
     private long emitNanos;
 
     public RandomSampleOperator(double probability, int seed) {
-        this.probability = probability;
-        this.seed = seed;
-        // TODO derive the threshold from the probability and a max cap
-        pageBatching = new PageBatching(ROWS_BATCH_THRESHOLD, MAX_BUFFERED_BATCHES);
+        finished = false;
+        outputPages = new LinkedList<>();
+        SplittableRandom random = new SplittableRandom(seed);
+        randomSamplingIterator = new RandomSamplingQuery.RandomSamplingIterator(Integer.MAX_VALUE, probability, random::nextInt);
+        randomSamplingIterator.nextDoc();
     }
 
     public record Factory(double probability, int seed) implements OperatorFactory {
@@ -79,7 +65,7 @@ public String describe() {
      */
     @Override
     public boolean needsInput() {
-        return collecting && pageBatching.capacityAvailable();
+        return finished == false;
     }
 
     /**
@@ -91,80 +77,64 @@ public boolean needsInput() {
     @Override
     public void addInput(Page page) {
         final var addStart = System.nanoTime();
-        collect(page);
+        createOutputPage(page);
+        rowsCollected += page.getPositionCount();
+        pagesCollected++;
+        page.releaseBlocks();
         collectNanos += System.nanoTime() - addStart;
     }
 
-    private void collect(Page page) {
-        pagesCollected++;
-        rowsCollected += page.getPositionCount();
-        pageBatching.addPage(page);
+    private void createOutputPage(Page page) {
+        final int[] sampledPositions = new int[page.getPositionCount()];
+        int sampledIdx = 0;
+        for (int i = randomSamplingIterator.docID(); i - rowsCollected < page.getPositionCount(); i = randomSamplingIterator.nextDoc()) {
+            sampledPositions[sampledIdx++] = i - rowsCollected;
+        }
+        if (sampledIdx > 0) {
+            outputPages.add(page.filter(Arrays.copyOf(sampledPositions, sampledIdx)));
+        }
     }
 
     /**
      * notifies the operator that it won't receive any more input pages
      */
     @Override
     public void finish() {
-        if (collecting && rowsCollected > 0) { // finish() can be called multiple times
-            pageBatching.flush();
-        }
-        collecting = false;
+        finished = true;
     }
 
     /**
      * whether the operator has finished processing all input pages and made the corresponding output pages available
      */
     @Override
     public boolean isFinished() {
-        return isFinished;
+        return finished && outputPages.isEmpty();
     }
 
-    /**
-     * returns non-null if output page available. Only called when isFinished() == false
-     *
-     * @throws UnsupportedOperationException if the operator is a {@link SinkOperator}
-     */
     @Override
     public Page getOutput() {
         final var emitStart = System.nanoTime();
-        Page page = emit();
+        Page page;
+        if (outputPages.isEmpty()) {
+            page = null;
+        } else {
+            page = outputPages.removeFirst();
+            pagesEmitted++;
+            rowsEmitted += page.getPositionCount();
+        }
         emitNanos += System.nanoTime() - emitStart;
         return page;
     }
 
-    private Page emit() {
-        if (currentSampling == null) {
-            if (pageBatching.hasNext() == false) {
-                if (collecting == false) {
-                    isFinished = true;
-                }
-                return null; // not enough pages on the input yet
-            }
-            final var currentBatch = pageBatching.next();
-            currentSampling = new BatchSampling(currentBatch, probability, seed);
-            batchesSampled++;
-        }
-
-        final var page = currentSampling.next();
-        if (page != null) {
-            rowsEmitted += page.getPositionCount();
-            pagesEmitted++;
-            return page;
-        } // else: current batch is exhausted
-
-        currentSampling.close();
-        currentSampling = null;
-        return emit();
-    }
-
     /**
      * notifies the operator that it won't be used anymore (i.e. none of the other methods called),
      * and its resources can be cleaned up
      */
     @Override
     public void close() {
-        pageBatching.close();
+        for (Page page : outputPages) {
+            page.releaseBlocks();
+        }
     }
 
     @Override
@@ -174,175 +144,12 @@ public String toString() {
 
     @Override
     public Operator.Status status() {
-        return new Status(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted, batchesSampled);
-    }
-
-    private static class SamplingIterator {
-
-        private final RandomSamplingQuery.RandomSamplingIterator samplingIterator;
-        private int nextDoc = -1;
-
-        SamplingIterator(int maxDoc, double probability, int seed) {
-            final SplittableRandom random = new SplittableRandom(BitMixer.mix(seed));
-            samplingIterator = new RandomSamplingQuery.RandomSamplingIterator(maxDoc, probability, random::nextInt);
-            advance();
-        }
-
-        boolean hasNext() {
-            return nextDoc != DocIdSetIterator.NO_MORE_DOCS;
-        }
-
-        int next() {
-            return nextDoc;
-        }
-
-        void advance() {
-            assert hasNext() : "No more docs to sample";
-            nextDoc = samplingIterator.nextDoc();
-        }
-    }
-
-    private record PagesBatch(ArrayDeque<Page> batch, int rowCount) {}
-
-    private static class PageBatching {
-
-        private final int collectingRowThreshold;
-        private final int maxBufferedBatches;
-
-        private final List<PagesBatch> batches = new ArrayList<>();
-
-        private int collectingBatchRowCount = 0;
-        private ArrayDeque<Page> collectingBatch = new ArrayDeque<>();
-
-        PageBatching(int collectingRowThreshold, int maxBufferedBatches) {
-            this.collectingRowThreshold = collectingRowThreshold;
-            this.maxBufferedBatches = maxBufferedBatches;
-        }
-
-        void addPage(Page page) {
-            collectingBatch.add(page);
-            collectingBatchRowCount += page.getPositionCount();
-            if (collectingBatchRowCount >= collectingRowThreshold) {
-                rotate();
-            }
-        }
-
-        private void rotate() {
-            batches.add(new PagesBatch(collectingBatch, collectingBatchRowCount));
-            collectingBatch = new ArrayDeque<>();
-            collectingBatchRowCount = 0;
-        }
-
-        boolean hasNext() {
-            return batches.isEmpty() == false;
-        }
-
-        PagesBatch next() {
-            return batches.removeFirst();
-        }
-
-        public boolean capacityAvailable() {
-            return batches.size() < maxBufferedBatches;
-        }
-
-        void flush() {
-            while (batches.isEmpty() == false) {
-                var batch = batches.removeFirst();
-                collectingBatch.addAll(batch.batch);
-                collectingBatchRowCount += batch.rowCount;
-            }
-            if (collectingBatch.isEmpty() == false) {
-                rotate();
-            }
-        }
-
-        void close() {
-            assert batches.isEmpty() : "There are still available batches";
-            assert collectingBatch.isEmpty() : "Current batch has not been rotated";
-        }
+        return new Status(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted);
     }
 
-    private static class BatchSampling {
-
-        private final ArrayDeque<Page> pagesDeque;
-        private final SamplingIterator samplingIterator;
-
-        private int rowsProcessed = 0;
-
-        BatchSampling(PagesBatch batch, double probability, int seed) {
-            pagesDeque = batch.batch;
-            samplingIterator = new SamplingIterator(batch.rowCount, probability, seed);
-        }
-
-        Page next() {
-            while (pagesDeque.isEmpty() == false) {
-                final var page = pagesDeque.poll();
-                final int positionCount = page.getPositionCount();
-                final int[] sampledPositions = new int[positionCount];
-                int sampledIdx = 0;
-
-                while (true) {
-                    if (samplingIterator.hasNext()) {
-                        var docOffset = samplingIterator.next() - rowsProcessed;
-                        if (docOffset < positionCount) {
-                            sampledPositions[sampledIdx++] = docOffset;
-                            samplingIterator.advance();
-                        } else {
-                            // position falls outside the current page
-                            break;
-                        }
-                    } else {
-                        // no more docs to sample
-                        drainPages();
-                        break;
-                    }
-                }
-                rowsProcessed += positionCount;
-
-                if (sampledIdx > 0) {
-                    var filter = Arrays.copyOf(sampledPositions, sampledIdx);
-                    return page.filter(filter);
-                } // else: fetch a new page (if any left)
-
-                releasePage(page);
-            }
-
-            return null;
-        }
-
-        private void drainPages() {
-            Page page;
-            do {
-                page = pagesDeque.poll();
-            } while (releasePage(page));
-        }
-
-        /**
-         * Returns true if there was a non-null page that was released.
-         */
-        private static boolean releasePage(Page page) {
-            if (page != null) {
-                page.releaseBlocks();
-                return true;
-            }
-            return false;
-        }
-
-        void close() {
-            assert pagesDeque.isEmpty() : "There are still unreleased pages";
-            assert samplingIterator.hasNext() == false : "There are still docs to sample";
-        }
-    }
-
-    private record Status(
-        long collectNanos,
-        long emitNanos,
-        int pagesCollected,
-        int pagesEmitted,
-        int rowsCollected,
-        int rowsEmitted,
-        int batchesSampled
-    ) implements Operator.Status {
+    private record Status(long collectNanos, long emitNanos, int pagesCollected, int pagesEmitted, int rowsCollected, int rowsEmitted)
+        implements
+            Operator.Status {
 
         public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
             Operator.Status.class,
@@ -357,7 +164,6 @@ private record Status(
                 streamInput.readVInt(),
                 streamInput.readVInt(),
                 streamInput.readVInt(),
-                streamInput.readVInt(),
                 streamInput.readVInt()
             );
         }
@@ -370,7 +176,6 @@ public void writeTo(StreamOutput out) throws IOException {
             out.writeVInt(pagesEmitted);
             out.writeVInt(rowsCollected);
             out.writeVInt(rowsEmitted);
-            out.writeVInt(batchesSampled);
         }
 
         @Override
@@ -393,7 +198,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
             builder.field("pages_emitted", pagesEmitted);
             builder.field("rows_collected", rowsCollected);
             builder.field("rows_emitted", rowsEmitted);
-            builder.field("batches_sampled", batchesSampled);
             return builder.endObject();
         }
 
@@ -407,13 +211,12 @@ public boolean equals(Object o) {
                 && pagesCollected == other.pagesCollected
                 && pagesEmitted == other.pagesEmitted
                 && rowsCollected == other.rowsCollected
-                && rowsEmitted == other.rowsEmitted
-                && batchesSampled == other.batchesSampled;
+                && rowsEmitted == other.rowsEmitted;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted, batchesSampled);
+            return Objects.hash(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted);
         }
 
         @Override
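
---

The change above drops the page-batching machinery and samples eagerly: a single `RandomSamplingQuery.RandomSamplingIterator` (built over `Integer.MAX_VALUE` positions) yields monotonically increasing global row ids, and `createOutputPage` keeps the ids that land in the page at hand, rebased to page-relative positions via `i - rowsCollected`. As a minimal, self-contained sketch of that position mapping, here is a standalone example that uses a plain ordered int iterator as a stand-in for the sampling iterator; the class name, page sizes, and the way the sample ids are produced below are illustrative only, not part of the PR:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PrimitiveIterator;
import java.util.SplittableRandom;
import java.util.stream.IntStream;

public class SamplingPositionMappingSketch {
    public static void main(String[] args) {
        double probability = 0.25;
        SplittableRandom random = new SplittableRandom(42);
        // Stand-in for RandomSamplingQuery.RandomSamplingIterator: an ordered stream of sampled
        // global row ids. The real iterator produces its ids differently; filtering a range is
        // enough to demonstrate how ids are mapped onto per-page positions.
        PrimitiveIterator.OfInt sampledRowIds = IntStream.range(0, 10_000)
            .filter(i -> random.nextDouble() < probability)
            .iterator();

        int rowsCollected = 0;                                                  // global rows seen so far, as in the operator
        int next = sampledRowIds.hasNext() ? sampledRowIds.nextInt() : Integer.MAX_VALUE; // mirrors the nextDoc() call in the constructor
        int[] pageSizes = { 1_000, 500, 2_000 };                                // hypothetical incoming page sizes
        for (int pageSize : pageSizes) {
            List<Integer> pagePositions = new ArrayList<>();
            // Keep every sampled id that falls inside this page, rebased to a page-relative position.
            while (next != Integer.MAX_VALUE && next - rowsCollected < pageSize) {
                pagePositions.add(next - rowsCollected);
                next = sampledRowIds.hasNext() ? sampledRowIds.nextInt() : Integer.MAX_VALUE;
            }
            System.out.println("page of " + pageSize + " rows -> sampled positions " + pagePositions);
            rowsCollected += pageSize;
        }
    }
}
```

One design consequence visible in the diff: with `needsInput()` now returning `finished == false`, sampled pages accumulate in `outputPages` without the backpressure the old `capacityAvailable()` check provided, so memory use depends on how promptly the driver drains `getOutput()`.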