Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125570.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125570
summary: ES|QL random sampling
area: Machine Learning
type: feature
issues: []
5 changes: 0 additions & 5 deletions docs/changelog/129797.yaml

This file was deleted.

2 changes: 2 additions & 0 deletions docs/reference/esql/esql-commands.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ endif::[]
* experimental:[] <<esql-lookup-join>>
* experimental:[] <<esql-mv_expand>>
* <<esql-rename>>
* experimental:[] <<esql-sample>>
* <<esql-sort>>
* <<esql-stats-by>>
* <<esql-where>>
Expand All @@ -70,6 +71,7 @@ include::processing-commands/limit.asciidoc[]
include::processing-commands/lookup.asciidoc[]
include::processing-commands/mv_expand.asciidoc[]
include::processing-commands/rename.asciidoc[]
include::processing-commands/sample.asciidoc[]
include::processing-commands/sort.asciidoc[]
include::processing-commands/stats.asciidoc[]
include::processing-commands/where.asciidoc[]

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion docs/reference/esql/functions/kibana/docs/date_trunc.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions docs/reference/esql/functions/kibana/docs/sample.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions docs/reference/esql/functions/types/to_ip.asciidoc

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions docs/reference/esql/processing-commands/sample.asciidoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[discrete]
[[esql-sample]]
=== `SAMPLE`

preview::[]

The `SAMPLE` command samples a fraction of the table rows.

**Syntax**

[source,esql]
----
SAMPLE probability
----

*Parameters*

`probability`::
The probability that a row is included in the sample. The value must be between 0 and 1, exclusive.

*Example*

[source.merge.styled,esql]
----
include::{esql-specs}/sample.csv-spec[tag=sampleForDocs]
----
[%header.monospaced.styled,format=dsv,separator=|]
|===
include::{esql-specs}/sample.csv-spec[tag=sampleForDocs-result]
|===
3 changes: 2 additions & 1 deletion docs/reference/rest-api/usage.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ GET /_xpack/usage
"lookup_join" : 0,
"change_point" : 0,
"completion": 0,
"rerank": 0
"rerank": 0,
"sample": 0
},
"queries" : {
"rest" : {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53);
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
import org.elasticsearch.search.aggregations.bucket.sampler.UnmappedSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.InternalRandomSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplerAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongRareTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
Expand Down Expand Up @@ -1209,6 +1210,9 @@ private void registerQueryParsers(List<SearchPlugin> plugins) {
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
}));
registerQuery(
new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
);

registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,34 @@ public final class RandomSamplingQuery extends Query {
* can be generated
*/
public RandomSamplingQuery(double p, int seed, int hash) {
if (p <= 0.0 || p >= 1.0) {
throw new IllegalArgumentException("RandomSampling probability must be between 0.0 and 1.0, was [" + p + "]");
}
checkProbabilityRange(p);
this.p = p;
this.seed = seed;
this.hash = hash;
}

/**
 * Verifies that the probability is strictly within the (0.0, 1.0) range.
 * {@code NaN} is rejected as well, because it is not inside that range.
 *
 * @param p the sampling probability to validate
 * @throws IllegalArgumentException in case of an invalid probability.
 */
public static void checkProbabilityRange(double p) throws IllegalArgumentException {
    // Expressed as a negated "in range" test: every comparison involving NaN is false,
    // so the previous form (p <= 0.0 || p >= 1.0) silently accepted NaN.
    if ((p > 0.0 && p < 1.0) == false) {
        throw new IllegalArgumentException("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]");
    }
}

/**
 * @return the sampling probability {@code p} this query was constructed with
 */
public double probability() {
return p;
}

/**
 * @return the {@code seed} value this query was constructed with
 */
public int seed() {
return seed;
}

/**
 * @return the {@code hash} value this query was constructed with
 */
public int hash() {
return hash;
}

@Override
public String toString(String field) {
return "RandomSamplingQuery{" + "p=" + p + ", seed=" + seed + ", hash=" + hash + '}';
Expand Down Expand Up @@ -97,13 +117,13 @@ public void visit(QueryVisitor visitor) {
/**
* A DocIDSetIter that skips a geometrically random number of documents
*/
static class RandomSamplingIterator extends DocIdSetIterator {
public static class RandomSamplingIterator extends DocIdSetIterator {
private final int maxDoc;
private final double p;
private final FastGeometric distribution;
private int doc = -1;

RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
public RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
this.maxDoc = maxDoc;
this.p = p;
this.distribution = new FastGeometric(rng, p);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery.checkProbabilityRange;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
 * Query builder for the {@code random_sampling} query. It carries a sampling probability,
 * a seed and a hash, and turns them into a {@link RandomSamplingQuery} in
 * {@link #doToQuery(SearchExecutionContext)}.
 */
public class RandomSamplingQueryBuilder extends AbstractQueryBuilder<RandomSamplingQueryBuilder> {

    public static final String NAME = "random_sampling";
    // NOTE(review): the probability is written to and parsed from XContent under the field name
    // "query" rather than "probability". Left unchanged because it is part of the serialized
    // format; confirm that the name is intentional.
    static final ParseField PROBABILITY = new ParseField("query");
    static final ParseField SEED = new ParseField("seed");
    static final ParseField HASH = new ParseField("hash");

    private final double probability;
    // Unless overridden via seed(int) or read from a stream, each builder draws its own random seed.
    private int seed = Randomness.get().nextInt();
    private int hash = 0;

    /**
     * Creates a builder with a random seed and a hash of {@code 0}.
     *
     * @param probability the sampling probability, validated by
     *                    {@link RandomSamplingQuery#checkProbabilityRange(double)}
     * @throws IllegalArgumentException if the probability is rejected by that check
     */
    public RandomSamplingQueryBuilder(double probability) {
        checkProbabilityRange(probability);
        this.probability = probability;
    }

    /**
     * Sets the seed.
     *
     * @param seed the seed to use instead of the randomly drawn default
     * @return this builder
     */
    public RandomSamplingQueryBuilder seed(int seed) {
        // The probability is final and was already validated by the public constructor,
        // so it is not re-checked here.
        this.seed = seed;
        return this;
    }

    /**
     * Reads a builder from the stream, in the order written by {@link #doWriteTo(StreamOutput)}:
     * probability, seed, hash.
     */
    public RandomSamplingQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.probability = in.readDouble();
        this.seed = in.readInt();
        this.hash = in.readInt();
    }

    /**
     * Sets the hash.
     *
     * @param hash the hash value; must not be {@code null}
     * @return this builder
     * @throws NullPointerException if {@code hash} is {@code null}
     */
    public RandomSamplingQueryBuilder hash(Integer hash) {
        // Explicit null check so that a null argument fails with a descriptive message
        // instead of an anonymous unboxing NullPointerException.
        this.hash = Objects.requireNonNull(hash, "hash must not be null");
        return this;
    }

    /**
     * @return the sampling probability
     */
    public double probability() {
        return probability;
    }

    /**
     * @return the seed
     */
    public int seed() {
        return seed;
    }

    /**
     * @return the hash
     */
    public int hash() {
        return hash;
    }

    /**
     * Writes probability, seed and hash, in that order.
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeDouble(probability);
        out.writeInt(seed);
        out.writeInt(hash);
    }

    /**
     * Renders the builder as an object named {@link #NAME} holding the probability, seed and hash fields.
     */
    @Override
    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(PROBABILITY.getPreferredName(), probability);
        builder.field(SEED.getPreferredName(), seed);
        builder.field(HASH.getPreferredName(), hash);
        builder.endObject();
    }

    // args[0] is the required probability; args[1] (seed) and args[2] (hash) are optional and
    // are only applied when present, so the builder defaults are kept otherwise.
    private static final ConstructingObjectParser<RandomSamplingQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            var randomSamplingQueryBuilder = new RandomSamplingQueryBuilder((double) args[0]);
            if (args[1] != null) {
                randomSamplingQueryBuilder.seed((int) args[1]);
            }
            if (args[2] != null) {
                randomSamplingQueryBuilder.hash((int) args[2]);
            }
            return randomSamplingQueryBuilder;
        }
    );

    static {
        PARSER.declareDouble(constructorArg(), PROBABILITY);
        PARSER.declareInt(optionalConstructorArg(), SEED);
        PARSER.declareInt(optionalConstructorArg(), HASH);
    }

    /**
     * Parses a builder from XContent using the fields declared on the parser.
     */
    public static RandomSamplingQueryBuilder fromXContent(XContentParser parser) throws IOException {
        return PARSER.apply(parser, null);
    }

    /**
     * Builds the {@link RandomSamplingQuery} from this builder's probability, seed and hash.
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        return new RandomSamplingQuery(probability, seed, hash);
    }

    @Override
    protected boolean doEquals(RandomSamplingQueryBuilder other) {
        // Double.compare matches the boxed Double hashing used in doHashCode(), keeping
        // equals and hashCode consistent for every double value.
        return Double.compare(probability, other.probability) == 0 && seed == other.seed && hash == other.hash;
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(probability, seed, hash);
    }

    /**
     * Returns the name of the writeable object
     */
    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * The minimal version of the recipient this object can be sent to
     */
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.RANDOM_SAMPLER_QUERY_BUILDER_8_19;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ public CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> getReque
"range",
"regexp",
"knn_score_doc",
"random_sampling",
"script",
"script_score",
"simple_query_string",
Expand Down
Loading