Skip to content

Commit bd82576

Browse files
committed
SampleOperaetorTests + fix status
1 parent 20ca56c commit bd82576

File tree

2 files changed

+72
-31
lines changed

2 files changed

+72
-31
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/SampleOperator.java

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@ public class SampleOperator implements Operator {
3131
private final Deque<Page> outputPages;
3232
private final RandomSamplingQuery.RandomSamplingIterator randomSamplingIterator;
3333

34-
private int pagesCollected = 0;
35-
private int pagesEmitted = 0;
36-
private int rowsCollected = 0;
34+
private int pagesProcessed = 0;
35+
private int rowsReceived = 0;
3736
private int rowsEmitted = 0;
3837

3938
private long collectNanos;
@@ -76,19 +75,19 @@ public boolean needsInput() {
7675
*/
7776
@Override
7877
public void addInput(Page page) {
79-
final var addStart = System.nanoTime();
78+
long startTime = System.nanoTime();
8079
createOutputPage(page);
81-
rowsCollected += page.getPositionCount();
82-
pagesCollected++;
80+
rowsReceived += page.getPositionCount();
8381
page.releaseBlocks();
84-
collectNanos += System.nanoTime() - addStart;
82+
pagesProcessed++;
83+
collectNanos += System.nanoTime() - startTime;
8584
}
8685

8786
private void createOutputPage(Page page) {
8887
final int[] sampledPositions = new int[page.getPositionCount()];
8988
int sampledIdx = 0;
90-
for (int i = randomSamplingIterator.docID(); i - rowsCollected < page.getPositionCount(); i = randomSamplingIterator.nextDoc()) {
91-
sampledPositions[sampledIdx++] = i - rowsCollected;
89+
for (int i = randomSamplingIterator.docID(); i - rowsReceived < page.getPositionCount(); i = randomSamplingIterator.nextDoc()) {
90+
sampledPositions[sampledIdx++] = i - rowsReceived;
9291
}
9392
if (sampledIdx > 0) {
9493
outputPages.add(page.filter(Arrays.copyOf(sampledPositions, sampledIdx)));
@@ -119,7 +118,6 @@ public Page getOutput() {
119118
page = null;
120119
} else {
121120
page = outputPages.removeFirst();
122-
pagesEmitted++;
123121
rowsEmitted += page.getPositionCount();
124122
}
125123
emitNanos += System.nanoTime() - emitStart;
@@ -139,15 +137,15 @@ public void close() {
139137

140138
@Override
141139
public String toString() {
142-
return "SampleOperator[sampled = " + rowsEmitted + "/" + rowsCollected + "]";
140+
return "SampleOperator[sampled = " + rowsEmitted + "/" + rowsReceived + "]";
143141
}
144142

145143
@Override
146144
public Operator.Status status() {
147-
return new Status(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted);
145+
return new Status(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted);
148146
}
149147

150-
private record Status(long collectNanos, long emitNanos, int pagesCollected, int pagesEmitted, int rowsCollected, int rowsEmitted)
148+
private record Status(long collectNanos, long emitNanos, int pagesProcessed, int rowsReceived, int rowsEmitted)
151149
implements
152150
Operator.Status {
153151

@@ -158,23 +156,15 @@ private record Status(long collectNanos, long emitNanos, int pagesCollected, int
158156
);
159157

160158
Status(StreamInput streamInput) throws IOException {
161-
this(
162-
streamInput.readVLong(),
163-
streamInput.readVLong(),
164-
streamInput.readVInt(),
165-
streamInput.readVInt(),
166-
streamInput.readVInt(),
167-
streamInput.readVInt()
168-
);
159+
this(streamInput.readVLong(), streamInput.readVLong(), streamInput.readVInt(), streamInput.readVInt(), streamInput.readVInt());
169160
}
170161

171162
@Override
172163
public void writeTo(StreamOutput out) throws IOException {
173164
out.writeVLong(collectNanos);
174165
out.writeVLong(emitNanos);
175-
out.writeVInt(pagesCollected);
176-
out.writeVInt(pagesEmitted);
177-
out.writeVInt(rowsCollected);
166+
out.writeVInt(pagesProcessed);
167+
out.writeVInt(rowsReceived);
178168
out.writeVInt(rowsEmitted);
179169
}
180170

@@ -194,9 +184,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
194184
if (builder.humanReadable()) {
195185
builder.field("emit_time", TimeValue.timeValueNanos(emitNanos));
196186
}
197-
builder.field("pages_collected", pagesCollected);
198-
builder.field("pages_emitted", pagesEmitted);
199-
builder.field("rows_collected", rowsCollected);
187+
builder.field("pages_processed", pagesProcessed);
188+
builder.field("rows_received", rowsReceived);
200189
builder.field("rows_emitted", rowsEmitted);
201190
return builder.endObject();
202191
}
@@ -208,15 +197,14 @@ public boolean equals(Object o) {
208197
Status other = (Status) o;
209198
return collectNanos == other.collectNanos
210199
&& emitNanos == other.emitNanos
211-
&& pagesCollected == other.pagesCollected
212-
&& pagesEmitted == other.pagesEmitted
213-
&& rowsCollected == other.rowsCollected
200+
&& pagesProcessed == other.pagesProcessed
201+
&& rowsReceived == other.rowsReceived
214202
&& rowsEmitted == other.rowsEmitted;
215203
}
216204

217205
@Override
218206
public int hashCode() {
219-
return Objects.hash(collectNanos, emitNanos, pagesCollected, pagesEmitted, rowsCollected, rowsEmitted);
207+
return Objects.hash(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted);
220208
}
221209

222210
@Override
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.operator;
9+
10+
import org.elasticsearch.compute.data.BlockFactory;
11+
import org.elasticsearch.compute.data.Page;
12+
import org.elasticsearch.compute.test.OperatorTestCase;
13+
import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator;
14+
import org.hamcrest.Matcher;
15+
16+
import java.util.List;
17+
import java.util.stream.LongStream;
18+
19+
import static org.hamcrest.Matchers.closeTo;
20+
import static org.hamcrest.Matchers.equalTo;
21+
import static org.hamcrest.Matchers.matchesPattern;
22+
23+
public class SampleOperatorTests extends OperatorTestCase {
24+
25+
@Override
26+
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
27+
return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size));
28+
}
29+
30+
@Override
31+
protected void assertSimpleOutput(List<Page> input, List<Page> results) {
32+
int inputCount = input.stream().mapToInt(Page::getPositionCount).sum();
33+
int outputCount = results.stream().mapToInt(Page::getPositionCount).sum();
34+
double meanExpectedOutputCount = 0.5 * inputCount;
35+
double stdDevExpectedOutputCount = Math.sqrt(meanExpectedOutputCount);
36+
assertThat((double) outputCount, closeTo(meanExpectedOutputCount, 10 * stdDevExpectedOutputCount));
37+
}
38+
39+
@Override
40+
protected Operator.OperatorFactory simple() {
41+
return new SampleOperator.Factory(0.5, randomInt());
42+
}
43+
44+
@Override
45+
protected Matcher<String> expectedDescriptionOfSimple() {
46+
return matchesPattern("SampleOperator\\[probability = 0.5, seed = -?\\d+]");
47+
}
48+
49+
@Override
50+
protected Matcher<String> expectedToStringOfSimple() {
51+
return equalTo("SampleOperator[sampled = 0/0]");
52+
}
53+
}

0 commit comments

Comments
 (0)