Skip to content

Commit 9ac2d39

Browse files
committed
ES|QL random sample
Squashed commit of the following: commit dd0bd21b53abf186199dbcab17bc3a9a4bf4c9d9 Author: Jan Kuipers <[email protected]> Date: Thu Mar 27 10:22:53 2025 +0100 remove extra changelog commit 2dfd7cb Author: Jan Kuipers <[email protected]> Date: Wed Mar 26 16:43:43 2025 +0100 fix random sample csv tests commit dcaf2aa Author: Jan Kuipers <[email protected]> Date: Wed Mar 26 15:11:00 2025 +0100 spotless commit cfe0f19 Author: Jan Kuipers <[email protected]> Date: Wed Mar 26 14:56:53 2025 +0100 Refactor sample correction once more commit 6b4696e Merge: 1a6bf4d 0e0214d Author: Jan Kuipers <[email protected]> Date: Wed Mar 26 14:01:19 2025 +0100 Merge branch 'main' of github.com:elastic/elasticsearch into feat/random_sample commit 1a6bf4d Author: Jan Kuipers <[email protected]> Date: Wed Mar 26 12:41:07 2025 +0100 Refactor sample correction commit d7b9434 Author: Jan Kuipers <[email protected]> Date: Wed Mar 26 08:43:42 2025 +0100 don't correct multiple stats commit 3d6947d Author: Jan Kuipers <[email protected]> Date: Tue Mar 25 11:04:20 2025 +0100 Update docs/changelog/125570.yaml commit 8f55e07 Author: Jan Kuipers <[email protected]> Date: Tue Mar 18 12:35:10 2025 +0100 correct aggregations for random sampling commit 38441b9 Author: Jan Kuipers <[email protected]> Date: Tue Mar 25 09:57:58 2025 +0100 Simplify RandomSampleOperator commit 0063737 Author: Bogdan Pintea <[email protected]> Date: Thu Mar 6 13:21:49 2025 +0100 Make CsvTests more node-count-induced variation tollerant commit 8fbc684 Merge: 213634e 47706b5 Author: Bogdan Pintea <[email protected]> Date: Wed Mar 5 22:06:47 2025 +0100 Merge remote-tracking branch 'upstream/main' into feat/random_sample commit 213634e Author: elasticsearchmachine <[email protected]> Date: Wed Mar 5 19:49:18 2025 +0000 [CI] Auto commit changes from spotless commit e5de4bf Merge: 0f9600c a92b1d6 Author: Bogdan Pintea <[email protected]> Date: Wed Mar 5 20:40:16 2025 +0100 Merge remote-tracking branch 'upstream/main' into feat/random_sample commit 0f9600c Author: Bogdan Pintea <[email protected]> Date: Wed Mar 5 20:29:05 2025 +0100 Make seed parameter optional. Various fixes commit 3cbd508 Author: Bogdan Pintea <[email protected]> Date: Wed Mar 5 16:17:07 2025 +0100 Add non-operator-related tests commit 15afc08 Author: elasticsearchmachine <[email protected]> Date: Mon Mar 3 14:15:57 2025 +0000 [CI] Auto commit changes from spotless commit 7abf28d Author: Bogdan Pintea <[email protected]> Date: Mon Mar 3 15:06:44 2025 +0100 Update docs/changelog/123879.yaml commit 36acb35 Author: Bogdan Pintea <[email protected]> Date: Mon Mar 3 15:00:42 2025 +0100 Add a random sample commadn This adds RANDOM_SAMPLE <probability> <seed>? command.
1 parent 8ced682 commit 9ac2d39

File tree

53 files changed

+4151
-2512
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+4151
-2512
lines changed

docs/changelog/125570.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 125570
2+
summary: ES|QL random sampling
3+
area: Machine Learning
4+
type: feature
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ static TransportVersion def(int id) {
206206
public static final TransportVersion ESQL_REPORT_ORIGINAL_TYPES = def(9_038_00_0);
207207
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO = def(9_039_0_00);
208208
public static final TransportVersion PROJECT_ID_IN_SNAPSHOT = def(9_040_0_00);
209+
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER = def(9_041_0_00);
209210

210211
/*
211212
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/search/SearchModule.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
import org.elasticsearch.search.aggregations.bucket.sampler.UnmappedSampler;
135135
import org.elasticsearch.search.aggregations.bucket.sampler.random.InternalRandomSampler;
136136
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplerAggregationBuilder;
137+
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
137138
import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
138139
import org.elasticsearch.search.aggregations.bucket.terms.LongRareTerms;
139140
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
@@ -1186,6 +1187,9 @@ private void registerQueryParsers(List<SearchPlugin> plugins) {
11861187
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
11871188
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
11881189
}));
1190+
registerQuery(
1191+
new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
1192+
);
11891193

11901194
registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);
11911195
}

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,34 @@ public final class RandomSamplingQuery extends Query {
4444
* can be generated
4545
*/
4646
public RandomSamplingQuery(double p, int seed, int hash) {
47-
if (p <= 0.0 || p >= 1.0) {
48-
throw new IllegalArgumentException("RandomSampling probability must be between 0.0 and 1.0, was [" + p + "]");
49-
}
47+
checkProbabilityRange(p);
5048
this.p = p;
5149
this.seed = seed;
5250
this.hash = hash;
5351
}
5452

53+
/**
54+
* Verifies that the probability is within the (0.0, 1.0) range.
55+
* @throws IllegalArgumentException in case of an invalid probability.
56+
*/
57+
public static void checkProbabilityRange(double p) throws IllegalArgumentException {
58+
if (p <= 0.0 || p >= 1.0) {
59+
throw new IllegalArgumentException("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]");
60+
}
61+
}
62+
63+
public double probability() {
64+
return p;
65+
}
66+
67+
public int seed() {
68+
return seed;
69+
}
70+
71+
public int hash() {
72+
return hash;
73+
}
74+
5575
@Override
5676
public String toString(String field) {
5777
return "RandomSamplingQuery{" + "p=" + p + ", seed=" + seed + ", hash=" + hash + '}';
@@ -98,13 +118,13 @@ public void visit(QueryVisitor visitor) {
98118
/**
99119
* A DocIDSetIter that skips a geometrically random number of documents
100120
*/
101-
static class RandomSamplingIterator extends DocIdSetIterator {
121+
public static class RandomSamplingIterator extends DocIdSetIterator {
102122
private final int maxDoc;
103123
private final double p;
104124
private final FastGeometric distribution;
105125
private int doc = -1;
106126

107-
RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
127+
public RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
108128
this.maxDoc = maxDoc;
109129
this.p = p;
110130
this.distribution = new FastGeometric(rng, p);
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.aggregations.bucket.sampler.random;
11+
12+
import org.apache.lucene.search.Query;
13+
import org.elasticsearch.TransportVersion;
14+
import org.elasticsearch.TransportVersions;
15+
import org.elasticsearch.common.Randomness;
16+
import org.elasticsearch.common.io.stream.StreamInput;
17+
import org.elasticsearch.common.io.stream.StreamOutput;
18+
import org.elasticsearch.index.query.AbstractQueryBuilder;
19+
import org.elasticsearch.index.query.SearchExecutionContext;
20+
import org.elasticsearch.xcontent.ConstructingObjectParser;
21+
import org.elasticsearch.xcontent.ParseField;
22+
import org.elasticsearch.xcontent.XContentBuilder;
23+
import org.elasticsearch.xcontent.XContentParser;
24+
25+
import java.io.IOException;
26+
import java.util.Objects;
27+
28+
import static org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery.checkProbabilityRange;
29+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
30+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
31+
32+
public class RandomSamplingQueryBuilder extends AbstractQueryBuilder<RandomSamplingQueryBuilder> {
33+
34+
public static final String NAME = "random_sampling";
35+
static final ParseField PROBABILITY = new ParseField("query");
36+
static final ParseField SEED = new ParseField("seed");
37+
static final ParseField HASH = new ParseField("hash");
38+
39+
private final double probability;
40+
private int seed = Randomness.get().nextInt();
41+
private int hash = 0;
42+
43+
public RandomSamplingQueryBuilder(double probability) {
44+
checkProbabilityRange(probability);
45+
this.probability = probability;
46+
}
47+
48+
public RandomSamplingQueryBuilder seed(int seed) {
49+
checkProbabilityRange(probability);
50+
this.seed = seed;
51+
return this;
52+
}
53+
54+
public RandomSamplingQueryBuilder(StreamInput in) throws IOException {
55+
super(in);
56+
this.probability = in.readDouble();
57+
this.seed = in.readInt();
58+
this.hash = in.readInt();
59+
}
60+
61+
public RandomSamplingQueryBuilder hash(Integer hash) {
62+
this.hash = hash;
63+
return this;
64+
}
65+
66+
public double probability() {
67+
return probability;
68+
}
69+
70+
public int seed() {
71+
return seed;
72+
}
73+
74+
public int hash() {
75+
return hash;
76+
}
77+
78+
@Override
79+
protected void doWriteTo(StreamOutput out) throws IOException {
80+
out.writeDouble(probability);
81+
out.writeInt(seed);
82+
out.writeInt(hash);
83+
}
84+
85+
@Override
86+
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
87+
builder.startObject(NAME);
88+
builder.field(PROBABILITY.getPreferredName(), probability);
89+
builder.field(SEED.getPreferredName(), seed);
90+
builder.field(HASH.getPreferredName(), hash);
91+
builder.endObject();
92+
}
93+
94+
private static final ConstructingObjectParser<RandomSamplingQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
95+
NAME,
96+
false,
97+
args -> {
98+
var randomSamplingQueryBuilder = new RandomSamplingQueryBuilder((double) args[0]);
99+
if (args[1] != null) {
100+
randomSamplingQueryBuilder.seed((int) args[1]);
101+
}
102+
if (args[2] != null) {
103+
randomSamplingQueryBuilder.hash((int) args[2]);
104+
}
105+
return randomSamplingQueryBuilder;
106+
}
107+
);
108+
109+
static {
110+
PARSER.declareDouble(constructorArg(), PROBABILITY);
111+
PARSER.declareInt(optionalConstructorArg(), SEED);
112+
PARSER.declareInt(optionalConstructorArg(), HASH);
113+
}
114+
115+
public static RandomSamplingQueryBuilder fromXContent(XContentParser parser) throws IOException {
116+
return PARSER.apply(parser, null);
117+
}
118+
119+
@Override
120+
protected Query doToQuery(SearchExecutionContext context) throws IOException {
121+
return new RandomSamplingQuery(probability, seed, hash);
122+
}
123+
124+
@Override
125+
protected boolean doEquals(RandomSamplingQueryBuilder other) {
126+
return probability == other.probability && seed == other.seed && hash == other.hash;
127+
}
128+
129+
@Override
130+
protected int doHashCode() {
131+
return Objects.hash(probability, seed, hash);
132+
}
133+
134+
/**
135+
* Returns the name of the writeable object
136+
*/
137+
@Override
138+
public String getWriteableName() {
139+
return NAME;
140+
}
141+
142+
/**
143+
* The minimal version of the recipient this object can be sent to
144+
*/
145+
@Override
146+
public TransportVersion getMinimalSupportedVersion() {
147+
return TransportVersions.RANDOM_SAMPLER_QUERY_BUILDER;
148+
}
149+
}

server/src/test/java/org/elasticsearch/search/SearchModuleTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ public CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> getReque
444444
"range",
445445
"regexp",
446446
"knn_score_doc",
447+
"random_sampling",
447448
"script",
448449
"script_score",
449450
"simple_query_string",
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.aggregations.bucket.sampler.random;
11+
12+
import org.apache.lucene.search.Query;
13+
import org.elasticsearch.index.query.SearchExecutionContext;
14+
import org.elasticsearch.test.AbstractQueryTestCase;
15+
import org.elasticsearch.xcontent.XContentParseException;
16+
17+
import java.io.IOException;
18+
19+
import static org.hamcrest.Matchers.equalTo;
20+
21+
public class RandomSamplingQueryBuilderTests extends AbstractQueryTestCase<RandomSamplingQueryBuilder> {
22+
23+
@Override
24+
protected RandomSamplingQueryBuilder doCreateTestQueryBuilder() {
25+
double probability = randomDoubleBetween(0.0, 1.0, false);
26+
var builder = new RandomSamplingQueryBuilder(probability);
27+
if (randomBoolean()) {
28+
builder.seed(randomInt());
29+
}
30+
if (randomBoolean()) {
31+
builder.hash(randomInt());
32+
}
33+
return builder;
34+
}
35+
36+
@Override
37+
protected void doAssertLuceneQuery(RandomSamplingQueryBuilder queryBuilder, Query query, SearchExecutionContext context)
38+
throws IOException {
39+
var rsQuery = asInstanceOf(RandomSamplingQuery.class, query);
40+
assertThat(rsQuery.probability(), equalTo(queryBuilder.probability()));
41+
assertThat(rsQuery.seed(), equalTo(queryBuilder.seed()));
42+
assertThat(rsQuery.hash(), equalTo(queryBuilder.hash()));
43+
}
44+
45+
@Override
46+
protected boolean supportsBoost() {
47+
return false;
48+
}
49+
50+
@Override
51+
protected boolean supportsQueryName() {
52+
return false;
53+
}
54+
55+
@Override
56+
public void testUnknownField() {
57+
var json = "{ \""
58+
+ RandomSamplingQueryBuilder.NAME
59+
+ "\" : {\"bogusField\" : \"someValue\", \""
60+
+ RandomSamplingQueryBuilder.PROBABILITY.getPreferredName()
61+
+ "\" : \""
62+
+ randomBoolean()
63+
+ "\", \""
64+
+ RandomSamplingQueryBuilder.SEED.getPreferredName()
65+
+ "\" : \""
66+
+ randomInt()
67+
+ "\", \""
68+
+ RandomSamplingQueryBuilder.HASH.getPreferredName()
69+
+ "\" : \""
70+
+ randomInt()
71+
+ "\" } }";
72+
var e = expectThrows(XContentParseException.class, () -> parseQuery(json));
73+
assertTrue(e.getMessage().contains("bogusField"));
74+
}
75+
}

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ public static boolean isSupported(String name) {
172172
return ATTRIBUTES_MAP.containsKey(name);
173173
}
174174

175+
public static boolean isScoreAttribute(Expression a) {
176+
return a instanceof MetadataAttribute ma && ma.name().equals(SCORE);
177+
}
178+
175179
@Override
176180
public boolean equals(Object obj) {
177181
if (false == super.equals(obj)) {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,4 +294,21 @@ public Page projectBlocks(int[] blockMapping) {
294294
}
295295
}
296296
}
297+
298+
public Page filter(int... positions) {
299+
Block[] filteredBlocks = new Block[blocks.length];
300+
boolean success = false;
301+
try {
302+
for (int i = 0; i < blocks.length; i++) {
303+
filteredBlocks[i] = getBlock(i).filter(positions);
304+
}
305+
success = true;
306+
} finally {
307+
releaseBlocks();
308+
if (success == false) {
309+
Releasables.closeExpectNoException(filteredBlocks);
310+
}
311+
}
312+
return new Page(filteredBlocks);
313+
}
297314
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FilterOperator.java

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.compute.operator;
99

10-
import org.elasticsearch.compute.data.Block;
1110
import org.elasticsearch.compute.data.BooleanBlock;
1211
import org.elasticsearch.compute.data.Page;
1312
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
@@ -69,20 +68,7 @@ protected Page process(Page page) {
6968
}
7069
positions = Arrays.copyOf(positions, rowCount);
7170

72-
Block[] filteredBlocks = new Block[page.getBlockCount()];
73-
boolean success = false;
74-
try {
75-
for (int i = 0; i < page.getBlockCount(); i++) {
76-
filteredBlocks[i] = page.getBlock(i).filter(positions);
77-
}
78-
success = true;
79-
} finally {
80-
page.releaseBlocks();
81-
if (success == false) {
82-
Releasables.closeExpectNoException(filteredBlocks);
83-
}
84-
}
85-
return new Page(filteredBlocks);
71+
return page.filter(positions);
8672
}
8773
}
8874

0 commit comments

Comments
 (0)