diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestService.java b/server/src/main/java/org/elasticsearch/ingest/IngestService.java index 06ae43183b368..82be81eed7c2f 100644 --- a/server/src/main/java/org/elasticsearch/ingest/IngestService.java +++ b/server/src/main/java/org/elasticsearch/ingest/IngestService.java @@ -52,7 +52,6 @@ import org.elasticsearch.cluster.service.MasterServiceTaskQueue; import org.elasticsearch.common.Priority; import org.elasticsearch.common.TriConsumer; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; @@ -1391,45 +1390,11 @@ private void attemptToSampleData( * We need both the original document and the fully updated document for sampling, so we make a copy of the original * before overwriting it here. We can discard it after sampling. */ - samplingService.maybeSample(projectMetadata, indexRequest.index(), () -> { - IndexRequest original = copyIndexRequestForSampling(indexRequest); - updateIndexRequestMetadata(original, originalDocumentMetadata); - return original; - }, ingestDocument); + samplingService.maybeSample(projectMetadata, originalDocumentMetadata.getIndex(), indexRequest, ingestDocument); } } - /** - * Creates a copy of an IndexRequest to be used by random sampling. - * @param original The IndexRequest to be copied - * @return A copy of the IndexRequest - */ - private IndexRequest copyIndexRequestForSampling(IndexRequest original) { - IndexRequest clonedRequest = new IndexRequest(original.index()); - clonedRequest.id(original.id()); - clonedRequest.routing(original.routing()); - clonedRequest.version(original.version()); - clonedRequest.versionType(original.versionType()); - clonedRequest.setPipeline(original.getPipeline()); - clonedRequest.setFinalPipeline(original.getFinalPipeline()); - clonedRequest.setIfSeqNo(original.ifSeqNo()); - clonedRequest.setIfPrimaryTerm(original.ifPrimaryTerm()); - clonedRequest.setRefreshPolicy(original.getRefreshPolicy()); - clonedRequest.waitForActiveShards(original.waitForActiveShards()); - clonedRequest.timeout(original.timeout()); - clonedRequest.opType(original.opType()); - clonedRequest.setParentTask(original.getParentTask()); - clonedRequest.setRequireDataStream(original.isRequireDataStream()); - clonedRequest.setRequireAlias(original.isRequireAlias()); - clonedRequest.setIncludeSourceOnError(original.getIncludeSourceOnError()); - BytesReference source = original.source(); - if (source != null) { - clonedRequest.source(source, original.getContentType()); - } - return clonedRequest; - } - private static void executePipeline( final IngestDocument ingestDocument, final Pipeline pipeline, diff --git a/server/src/main/java/org/elasticsearch/ingest/SamplingService.java b/server/src/main/java/org/elasticsearch/ingest/SamplingService.java index 477ef12a5c042..21504f8105198 100644 --- a/server/src/main/java/org/elasticsearch/ingest/SamplingService.java +++ b/server/src/main/java/org/elasticsearch/ingest/SamplingService.java @@ -9,27 +9,76 @@ package org.elasticsearch.ingest; +import org.elasticsearch.action.admin.indices.sampling.SamplingConfiguration; +import org.elasticsearch.action.admin.indices.sampling.SamplingMetadata; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; +import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.FeatureFlag; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.script.IngestConditionalScript; +import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptService; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.json.JsonXContent; +import java.io.IOException; +import java.lang.ref.SoftReference; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.LongAdder; +import java.util.function.LongSupplier; import java.util.function.Supplier; public class SamplingService implements ClusterStateListener { + public static final boolean RANDOM_SAMPLING_FEATURE_FLAG = new FeatureFlag("random_sampling").isEnabled(); private static final Logger logger = LogManager.getLogger(SamplingService.class); private final ScriptService scriptService; private final ClusterService clusterService; + private final ProjectResolver projectResolver; + private final LongSupplier relativeMillisTimeSupplier; + private final LongSupplier statsTimeSupplier = System::nanoTime; + private final Random random; + /* + * This Map contains the samples that exist on this node. They are not persisted to disk. They are stored as SoftReferences so that + * sampling does not contribute to a node running out of memory. The idea is that access to samples is desirable, but not critical. We + * make a best effort to keep them around, but do not worry about the complexity or cost of making them durable. + */ + private final Map> samples = new ConcurrentHashMap<>(); - public SamplingService(ScriptService scriptService, ClusterService clusterService) { + public SamplingService( + ScriptService scriptService, + ClusterService clusterService, + ProjectResolver projectResolver, + LongSupplier relativeMillisTimeSupplier + ) { this.scriptService = scriptService; this.clusterService = clusterService; + this.projectResolver = projectResolver; + this.relativeMillisTimeSupplier = relativeMillisTimeSupplier; + random = Randomness.get(); } /** @@ -38,7 +87,13 @@ public SamplingService(ScriptService scriptService, ClusterService clusterServic * @param indexRequest The raw request to potentially sample */ public void maybeSample(ProjectMetadata projectMetadata, IndexRequest indexRequest) { - maybeSample(projectMetadata, indexRequest.index(), () -> indexRequest, () -> { + maybeSample(projectMetadata, indexRequest.index(), indexRequest, () -> { + /* + * The conditional scripts used by random sampling work off of IngestDocuments, in the same way conditionals do in pipelines. In + * this case, we did not have an IngestDocument (which happens when there are no pipelines). So we construct one with the same + * fields as this IndexRequest for use in conditionals. It is created in this lambda to avoid the expensive sourceAsMap call + * if the condition is never executed. + */ Map sourceAsMap; try { sourceAsMap = indexRequest.sourceAsMap(); @@ -58,31 +113,139 @@ public void maybeSample(ProjectMetadata projectMetadata, IndexRequest indexReque } /** - * + * Potentially samples the given indexRequest, depending on the existing sampling configuration. * @param projectMetadata Used to get the sampling configuration - * @param indexRequestSupplier A supplier for the raw request to potentially sample + * @param indexRequest The raw request to potentially sample * @param ingestDocument The IngestDocument used for evaluating any conditionals that are part of the sample configuration */ - public void maybeSample( - ProjectMetadata projectMetadata, - String indexName, - Supplier indexRequestSupplier, - IngestDocument ingestDocument - ) { - maybeSample(projectMetadata, indexName, indexRequestSupplier, () -> ingestDocument); + public void maybeSample(ProjectMetadata projectMetadata, String indexName, IndexRequest indexRequest, IngestDocument ingestDocument) { + maybeSample(projectMetadata, indexName, indexRequest, () -> ingestDocument); } private void maybeSample( ProjectMetadata projectMetadata, String indexName, - Supplier indexRequest, + IndexRequest indexRequest, Supplier ingestDocumentSupplier ) { - // TODO Sampling logic to go here in the near future + if (RANDOM_SAMPLING_FEATURE_FLAG == false) { + return; + } + long startTime = statsTimeSupplier.getAsLong(); + SamplingMetadata samplingMetadata = projectMetadata.custom(SamplingMetadata.TYPE); + if (samplingMetadata == null) { + return; + } + SamplingConfiguration samplingConfig = samplingMetadata.getIndexToSamplingConfigMap().get(indexName); + ProjectId projectId = projectMetadata.id(); + if (samplingConfig == null) { + return; + } + SoftReference sampleInfoReference = samples.compute( + new ProjectIndex(projectId, indexName), + (k, v) -> v == null || v.get() == null + ? new SoftReference<>( + new SampleInfo(samplingConfig.maxSamples(), samplingConfig.timeToLive(), relativeMillisTimeSupplier.getAsLong()) + ) + : v + ); + SampleInfo sampleInfo = sampleInfoReference.get(); + if (sampleInfo == null) { + return; + } + SampleStats stats = sampleInfo.stats; + stats.potentialSamples.increment(); + try { + if (sampleInfo.hasCapacity() == false) { + stats.samplesRejectedForMaxSamplesExceeded.increment(); + return; + } + if (random.nextDouble() >= samplingConfig.rate()) { + stats.samplesRejectedForRate.increment(); + return; + } + String condition = samplingConfig.condition(); + if (condition != null) { + if (sampleInfo.script == null || sampleInfo.factory == null) { + // We don't want to pay for synchronization because worst case, we compile the script twice + long compileScriptStartTime = statsTimeSupplier.getAsLong(); + try { + if (sampleInfo.compilationFailed) { + // we don't want to waste time -- if the script failed to compile once it will just fail again + stats.samplesRejectedForException.increment(); + return; + } else { + Script script = getScript(condition); + sampleInfo.setScript(script, scriptService.compile(script, IngestConditionalScript.CONTEXT)); + } + } catch (Exception e) { + sampleInfo.compilationFailed = true; + throw e; + } finally { + stats.timeCompilingCondition.add((statsTimeSupplier.getAsLong() - compileScriptStartTime)); + } + } + } + if (condition != null + && evaluateCondition(ingestDocumentSupplier, sampleInfo.script, sampleInfo.factory, sampleInfo.stats) == false) { + stats.samplesRejectedForCondition.increment(); + return; + } + RawDocument sample = getRawDocumentForIndexRequest(projectId, indexName, indexRequest); + if (sampleInfo.offer(sample)) { + stats.samples.increment(); + logger.trace("Sampling " + indexRequest); + } else { + stats.samplesRejectedForMaxSamplesExceeded.increment(); + } + } catch (Exception e) { + stats.samplesRejectedForException.increment(); + /* + * We potentially overwrite a previous exception here. But the thinking is that the user will pretty rapidly iterate on + * exceptions as they come up, and this avoids the overhead and complexity of keeping track of multiple exceptions. + */ + stats.lastException = e; + logger.debug("Error performing sampling for " + indexName, e); + } finally { + stats.timeSampling.add((statsTimeSupplier.getAsLong() - startTime)); + } + } + + /** + * Gets the sample for the given projectId and index on this node only. The sample is not persistent. + * @param projectId The project that this sample is for + * @param index The index that the sample is for + * @return The raw documents in the sample on this node, or an empty list if there are none + */ + public List getLocalSample(ProjectId projectId, String index) { + SoftReference sampleInfoReference = samples.get(new ProjectIndex(projectId, index)); + SampleInfo sampleInfo = sampleInfoReference == null ? null : sampleInfoReference.get(); + return sampleInfo == null ? List.of() : Arrays.stream(sampleInfo.getRawDocuments()).filter(Objects::nonNull).toList(); + } + + /** + * Gets the sample stats for the given projectId and index on this node only. The stats are not persistent. They are reset when the + * node restarts for example. + * @param projectId The project that this sample is for + * @param index The index that the sample is for + * @return Current stats on this node for this sample + */ + public SampleStats getLocalSampleStats(ProjectId projectId, String index) { + SoftReference sampleInfoReference = samples.get(new ProjectIndex(projectId, index)); + SampleInfo sampleInfo = sampleInfoReference.get(); + return sampleInfo == null ? new SampleStats() : sampleInfo.stats; } public boolean atLeastOneSampleConfigured() { - return false; // TODO Return true if there is at least one sample in the cluster state + if (RANDOM_SAMPLING_FEATURE_FLAG) { + SamplingMetadata samplingMetadata = clusterService.state() + .projectState(projectResolver.getProjectId()) + .metadata() + .custom(SamplingMetadata.TYPE); + return samplingMetadata != null && samplingMetadata.getIndexToSamplingConfigMap().isEmpty() == false; + } else { + return false; + } } @Override @@ -90,4 +253,369 @@ public void clusterChanged(ClusterChangedEvent event) { // TODO: React to sampling config changes } + private boolean evaluateCondition( + Supplier ingestDocumentSupplier, + Script script, + IngestConditionalScript.Factory factory, + SampleStats stats + ) { + long conditionStartTime = statsTimeSupplier.getAsLong(); + boolean passedCondition = factory.newInstance(script.getParams(), ingestDocumentSupplier.get().getUnmodifiableSourceAndMetadata()) + .execute(); + stats.timeEvaluatingCondition.add((statsTimeSupplier.getAsLong() - conditionStartTime)); + return passedCondition; + } + + private static Script getScript(String conditional) throws IOException { + logger.debug("Parsing script for conditional " + conditional); + try ( + XContentBuilder builder = XContentBuilder.builder(JsonXContent.jsonXContent).map(Map.of("source", conditional)); + XContentParser parser = XContentHelper.createParserNotCompressed( + LoggingDeprecationHandler.XCONTENT_PARSER_CONFIG, + BytesReference.bytes(builder), + XContentType.JSON + ) + ) { + return Script.parse(parser); + } + } + + /* + * This represents a raw document as the user sent it to us in an IndexRequest. It only holds onto the information needed for the + * sampling API, rather than holding all of the fields a user might send in an IndexRequest. + */ + public record RawDocument(ProjectId projectId, String indexName, byte[] source, XContentType contentType) implements Writeable { + + public RawDocument(StreamInput in) throws IOException { + this(ProjectId.readFrom(in), in.readString(), in.readByteArray(), in.readEnum(XContentType.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + projectId.writeTo(out); + out.writeString(indexName); + out.writeByteArray(source); + XContentHelper.writeTo(out, contentType); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RawDocument rawDocument = (RawDocument) o; + return Objects.equals(projectId, rawDocument.projectId) + && Objects.equals(indexName, rawDocument.indexName) + && Arrays.equals(source, rawDocument.source) + && contentType == rawDocument.contentType; + } + + @Override + public int hashCode() { + int result = Objects.hash(projectId, indexName, contentType); + result = 31 * result + Arrays.hashCode(source); + return result; + } + } + + /* + * This creates a RawDocument from the indexRequest. The source bytes of the indexRequest are copied into the RawDocument. So the + * RawDocument might be a relatively expensive object memory-wise. Since the bytes are copied, subsequent changes to the indexRequest + * are not reflected in the RawDocument + */ + private RawDocument getRawDocumentForIndexRequest(ProjectId projectId, String indexName, IndexRequest indexRequest) { + BytesReference sourceReference = indexRequest.source(); + assert sourceReference != null : "Cannot sample an IndexRequest with no source"; + byte[] source = sourceReference.array(); + final byte[] sourceCopy = new byte[sourceReference.length()]; + System.arraycopy(source, sourceReference.arrayOffset(), sourceCopy, 0, sourceReference.length()); + return new RawDocument(projectId, indexName, sourceCopy, indexRequest.getContentType()); + } + + public static final class SampleStats implements Writeable, ToXContent { + // These are all non-private for the sake of unit testing + final LongAdder samples = new LongAdder(); + final LongAdder potentialSamples = new LongAdder(); + final LongAdder samplesRejectedForMaxSamplesExceeded = new LongAdder(); + final LongAdder samplesRejectedForCondition = new LongAdder(); + final LongAdder samplesRejectedForRate = new LongAdder(); + final LongAdder samplesRejectedForException = new LongAdder(); + final LongAdder timeSampling = new LongAdder(); + final LongAdder timeEvaluatingCondition = new LongAdder(); + final LongAdder timeCompilingCondition = new LongAdder(); + Exception lastException = null; + + public SampleStats() {} + + public SampleStats(StreamInput in) throws IOException { + potentialSamples.add(in.readLong()); + samplesRejectedForMaxSamplesExceeded.add(in.readLong()); + samplesRejectedForCondition.add(in.readLong()); + samplesRejectedForRate.add(in.readLong()); + samplesRejectedForException.add(in.readLong()); + samples.add(in.readLong()); + timeSampling.add(in.readLong()); + timeEvaluatingCondition.add(in.readLong()); + timeCompilingCondition.add(in.readLong()); + if (in.readBoolean()) { + lastException = in.readException(); + } else { + lastException = null; + } + } + + public long getSamples() { + return samples.longValue(); + } + + public long getPotentialSamples() { + return potentialSamples.longValue(); + } + + public long getSamplesRejectedForMaxSamplesExceeded() { + return samplesRejectedForMaxSamplesExceeded.longValue(); + } + + public long getSamplesRejectedForCondition() { + return samplesRejectedForCondition.longValue(); + } + + public long getSamplesRejectedForRate() { + return samplesRejectedForRate.longValue(); + } + + public long getSamplesRejectedForException() { + return samplesRejectedForException.longValue(); + } + + public TimeValue getTimeSampling() { + return TimeValue.timeValueNanos(timeSampling.longValue()); + } + + public TimeValue getTimeEvaluatingCondition() { + return TimeValue.timeValueNanos(timeEvaluatingCondition.longValue()); + } + + public TimeValue getTimeCompilingCondition() { + return TimeValue.timeValueNanos(timeCompilingCondition.longValue()); + } + + public Exception getLastException() { + return lastException; + } + + @Override + public String toString() { + return "potential_samples: " + + potentialSamples + + ", samples_rejected_for_max_samples_exceeded: " + + samplesRejectedForMaxSamplesExceeded + + ", samples_rejected_for_condition: " + + samplesRejectedForCondition + + ", samples_rejected_for_rate: " + + samplesRejectedForRate + + ", samples_rejected_for_exception: " + + samplesRejectedForException + + ", samples_accepted: " + + samples + + ", time_sampling: " + + (timeSampling.longValue() / 1000000) + + ", time_evaluating_condition: " + + (timeEvaluatingCondition.longValue() / 1000000) + + ", time_compiling_condition: " + + (timeCompilingCondition.longValue() / 1000000); + } + + public SampleStats combine(SampleStats other) { + SampleStats result = new SampleStats(); + addAllFields(this, result); + addAllFields(other, result); + return result; + } + + private static void addAllFields(SampleStats source, SampleStats dest) { + dest.potentialSamples.add(source.potentialSamples.longValue()); + dest.samplesRejectedForMaxSamplesExceeded.add(source.samplesRejectedForMaxSamplesExceeded.longValue()); + dest.samplesRejectedForCondition.add(source.samplesRejectedForCondition.longValue()); + dest.samplesRejectedForRate.add(source.samplesRejectedForRate.longValue()); + dest.samplesRejectedForException.add(source.samplesRejectedForException.longValue()); + dest.samples.add(source.samples.longValue()); + dest.timeSampling.add(source.timeSampling.longValue()); + dest.timeEvaluatingCondition.add(source.timeEvaluatingCondition.longValue()); + dest.timeCompilingCondition.add(source.timeCompilingCondition.longValue()); + if (dest.lastException == null) { + dest.lastException = source.lastException; + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field("potential_samples", potentialSamples.longValue()); + builder.field("samples_rejected_for_max_samples_exceeded", samplesRejectedForMaxSamplesExceeded.longValue()); + builder.field("samples_rejected_for_condition", samplesRejectedForCondition.longValue()); + builder.field("samples_rejected_for_rate", samplesRejectedForRate.longValue()); + builder.field("samples_rejected_for_exception", samplesRejectedForException.longValue()); + builder.field("samples_accepted", samples.longValue()); + builder.field("time_sampling", (timeSampling.longValue() / 1000000)); + builder.field("time_evaluating_condition", (timeEvaluatingCondition.longValue() / 1000000)); + builder.field("time_compiling_condition", (timeCompilingCondition.longValue() / 1000000)); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeLong(potentialSamples.longValue()); + out.writeLong(samplesRejectedForMaxSamplesExceeded.longValue()); + out.writeLong(samplesRejectedForCondition.longValue()); + out.writeLong(samplesRejectedForRate.longValue()); + out.writeLong(samplesRejectedForException.longValue()); + out.writeLong(samples.longValue()); + out.writeLong(timeSampling.longValue()); + out.writeLong(timeEvaluatingCondition.longValue()); + out.writeLong(timeCompilingCondition.longValue()); + if (lastException == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeException(lastException); + } + } + + /* + * equals and hashCode are implemented for the sake of testing serialization. Since this class is mutable, these ought to never be + * used outside of testing. + */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SampleStats that = (SampleStats) o; + if (samples.longValue() != that.samples.longValue()) { + return false; + } + if (potentialSamples.longValue() != that.potentialSamples.longValue()) { + return false; + } + if (samplesRejectedForMaxSamplesExceeded.longValue() != that.samplesRejectedForMaxSamplesExceeded.longValue()) { + return false; + } + if (samplesRejectedForCondition.longValue() != that.samplesRejectedForCondition.longValue()) { + return false; + } + if (samplesRejectedForRate.longValue() != that.samplesRejectedForRate.longValue()) { + return false; + } + if (samplesRejectedForException.longValue() != that.samplesRejectedForException.longValue()) { + return false; + } + if (timeSampling.longValue() != that.timeSampling.longValue()) { + return false; + } + if (timeEvaluatingCondition.longValue() != that.timeEvaluatingCondition.longValue()) { + return false; + } + if (timeCompilingCondition.longValue() != that.timeCompilingCondition.longValue()) { + return false; + } + return exceptionsAreEqual(lastException, that.lastException); + } + + private boolean exceptionsAreEqual(Exception e1, Exception e2) { + if (e1 == null && e2 == null) { + return true; + } + if (e1 == null || e2 == null) { + return false; + } + return e1.getClass().equals(e2.getClass()) && e1.getMessage().equals(e2.getMessage()); + } + + @Override + public int hashCode() { + return Objects.hash( + samples.longValue(), + potentialSamples.longValue(), + samplesRejectedForMaxSamplesExceeded.longValue(), + samplesRejectedForCondition.longValue(), + samplesRejectedForRate.longValue(), + samplesRejectedForException.longValue(), + timeSampling.longValue(), + timeEvaluatingCondition.longValue(), + timeCompilingCondition.longValue() + ) + hashException(lastException); + } + + private int hashException(Exception e) { + if (e == null) { + return 0; + } else { + return Objects.hash(e.getClass(), e.getMessage()); + } + } + } + + /* + * This is used internally to store information about a sample in the samples Map. + */ + private static final class SampleInfo { + private final RawDocument[] rawDocuments; + private final SampleStats stats; + private final long expiration; + private final TimeValue timeToLive; + private volatile Script script; + private volatile IngestConditionalScript.Factory factory; + private volatile boolean compilationFailed = false; + private volatile boolean isFull = false; + private final AtomicInteger arrayIndex = new AtomicInteger(0); + + SampleInfo(int maxSamples, TimeValue timeToLive, long relativeNowMillis) { + this.timeToLive = timeToLive; + this.rawDocuments = new RawDocument[maxSamples]; + this.stats = new SampleStats(); + this.expiration = (timeToLive == null ? TimeValue.timeValueDays(5).millis() : timeToLive.millis()) + relativeNowMillis; + } + + public boolean hasCapacity() { + return isFull == false; + } + + /* + * This returns the array of raw documents. It's size will be the maximum number of raw documents allowed in this sample. Some (or + * all) elements could be null. + */ + public RawDocument[] getRawDocuments() { + return rawDocuments; + } + + /* + * Adds the rawDocument to the sample if there is capacity. Returns true if it adds it, or false if it does not. + */ + public boolean offer(RawDocument rawDocument) { + int index = arrayIndex.getAndIncrement(); + if (index < rawDocuments.length) { + rawDocuments[index] = rawDocument; + if (index == rawDocuments.length - 1) { + isFull = true; + } + return true; + } + return false; + } + + void setScript(Script script, IngestConditionalScript.Factory factory) { + this.script = script; + this.factory = factory; + } + } + + /* + * This is meant to be used internally as the key of the map of samples + */ + private record ProjectIndex(ProjectId projectId, String indexName) {}; + } diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 7a598475fc456..f1d5b6cc804bf 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -726,7 +726,12 @@ private void construct( FeatureService featureService = new FeatureService(pluginsService.loadServiceProviders(FeatureSpecification.class)); - SamplingService samplingService = new SamplingService(scriptService, clusterService); + SamplingService samplingService = new SamplingService( + scriptService, + clusterService, + projectResolver, + threadPool.relativeTimeInMillisSupplier() + ); modules.bindToInstance(SamplingService.class, samplingService); clusterService.addListener(samplingService); diff --git a/server/src/test/java/org/elasticsearch/ingest/SamplingServiceRawDocumentTests.java b/server/src/test/java/org/elasticsearch/ingest/SamplingServiceRawDocumentTests.java new file mode 100644 index 0000000000000..3734d2dcc3f7b --- /dev/null +++ b/server/src/test/java/org/elasticsearch/ingest/SamplingServiceRawDocumentTests.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.ingest; + +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.SamplingService.RawDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; + +public class SamplingServiceRawDocumentTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return RawDocument::new; + } + + @Override + protected RawDocument createTestInstance() { + return new RawDocument( + randomProjectIdOrDefault(), + randomIdentifier(), + randomByteArrayOfLength(randomIntBetween(10, 1000)), + randomFrom(XContentType.values()) + ); + } + + @Override + protected RawDocument mutateInstance(RawDocument instance) throws IOException { + ProjectId projectId = instance.projectId(); + String indexName = instance.indexName(); + byte[] source = instance.source(); + XContentType xContentType = instance.contentType(); + + switch (between(0, 3)) { + case 0 -> projectId = randomValueOtherThan(projectId, ESTestCase::randomProjectIdOrDefault); + case 1 -> indexName = randomValueOtherThan(indexName, ESTestCase::randomIdentifier); + case 2 -> source = randomByteArrayOfLength(randomIntBetween(100, 1000)); + case 3 -> xContentType = randomValueOtherThan(xContentType, () -> randomFrom(XContentType.values())); + default -> throw new IllegalArgumentException("Should never get here"); + } + return new RawDocument(projectId, indexName, source, xContentType); + } +} diff --git a/server/src/test/java/org/elasticsearch/ingest/SamplingServiceSampleStatsTests.java b/server/src/test/java/org/elasticsearch/ingest/SamplingServiceSampleStatsTests.java new file mode 100644 index 0000000000000..df8973d27ce3a --- /dev/null +++ b/server/src/test/java/org/elasticsearch/ingest/SamplingServiceSampleStatsTests.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.ingest; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.ingest.SamplingService.SampleStats; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertNotSame; + +public class SamplingServiceSampleStatsTests extends AbstractWireSerializingTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return SampleStats::new; + } + + @Override + protected SampleStats createTestInstance() { + SampleStats stats = new SampleStats(); + stats.samples.add(randomReasonableLong()); + stats.potentialSamples.add(randomReasonableLong()); + stats.samplesRejectedForMaxSamplesExceeded.add(randomReasonableLong()); + stats.samplesRejectedForCondition.add(randomReasonableLong()); + stats.samplesRejectedForRate.add(randomReasonableLong()); + stats.samplesRejectedForException.add(randomReasonableLong()); + stats.timeSampling.add(randomReasonableLong()); + stats.timeEvaluatingCondition.add(randomReasonableLong()); + stats.timeCompilingCondition.add(randomReasonableLong()); + stats.lastException = randomBoolean() ? null : new ElasticsearchException(randomAlphanumericOfLength(10)); + return stats; + } + + /* + * This is to avoid overflow errors in these tests. + */ + private long randomReasonableLong() { + long randomLong = randomNonNegativeLong(); + if (randomLong > Long.MAX_VALUE / 2) { + return randomLong / 2; + } else { + return randomLong; + } + } + + @Override + protected SampleStats mutateInstance(SampleStats instance) throws IOException { + SampleStats mutated = instance.combine(new SampleStats()); + switch (between(0, 9)) { + case 0 -> mutated.samples.add(1); + case 1 -> mutated.potentialSamples.add(1); + case 2 -> mutated.samplesRejectedForMaxSamplesExceeded.add(1); + case 3 -> mutated.samplesRejectedForCondition.add(1); + case 4 -> mutated.samplesRejectedForRate.add(1); + case 5 -> mutated.samplesRejectedForException.add(1); + case 6 -> mutated.timeSampling.add(1); + case 7 -> mutated.timeEvaluatingCondition.add(1); + case 8 -> mutated.timeCompilingCondition.add(1); + case 9 -> mutated.lastException = mutated.lastException == null + ? new ElasticsearchException(randomAlphanumericOfLength(10)) + : null; + default -> throw new IllegalArgumentException("Should never get here"); + } + return mutated; + } + + public void testCombine() { + SampleStats stats1 = createTestInstance(); + stats1.lastException = null; + SampleStats combinedWithEmpty = stats1.combine(new SampleStats()); + assertThat(combinedWithEmpty, equalTo(stats1)); + assertNotSame(stats1, combinedWithEmpty); + SampleStats stats2 = createTestInstance(); + SampleStats stats1CombineStats2 = stats1.combine(stats2); + SampleStats stats2CombineStats1 = stats2.combine(stats1); + assertThat(stats1CombineStats2, equalTo(stats2CombineStats1)); + assertThat(stats1CombineStats2.getSamples(), equalTo(stats1.getSamples() + stats2.getSamples())); + assertThat(stats1CombineStats2.getPotentialSamples(), equalTo(stats1.getPotentialSamples() + stats2.getPotentialSamples())); + assertThat( + stats1CombineStats2.getSamplesRejectedForMaxSamplesExceeded(), + equalTo(stats1.getSamplesRejectedForMaxSamplesExceeded() + stats2.getSamplesRejectedForMaxSamplesExceeded()) + ); + assertThat( + stats1CombineStats2.getSamplesRejectedForCondition(), + equalTo(stats1.getSamplesRejectedForCondition() + stats2.getSamplesRejectedForCondition()) + ); + assertThat( + stats1CombineStats2.getSamplesRejectedForRate(), + equalTo(stats1.getSamplesRejectedForRate() + stats2.getSamplesRejectedForRate()) + ); + assertThat( + stats1CombineStats2.getSamplesRejectedForException(), + equalTo(stats1.getSamplesRejectedForException() + stats2.getSamplesRejectedForException()) + ); + assertThat( + stats1CombineStats2.getTimeSampling(), + equalTo(TimeValue.timeValueNanos(stats1.getTimeSampling().nanos() + stats2.getTimeSampling().nanos())) + ); + assertThat( + stats1CombineStats2.getTimeEvaluatingCondition(), + equalTo(TimeValue.timeValueNanos(stats1.getTimeEvaluatingCondition().nanos() + stats2.getTimeEvaluatingCondition().nanos())) + ); + assertThat( + stats1CombineStats2.getTimeCompilingCondition(), + equalTo(TimeValue.timeValueNanos(stats1.getTimeCompilingCondition().nanos() + stats2.getTimeCompilingCondition().nanos())) + ); + } +} diff --git a/server/src/test/java/org/elasticsearch/ingest/SamplingServiceTests.java b/server/src/test/java/org/elasticsearch/ingest/SamplingServiceTests.java new file mode 100644 index 0000000000000..dc8d4067e95ab --- /dev/null +++ b/server/src/test/java/org/elasticsearch/ingest/SamplingServiceTests.java @@ -0,0 +1,237 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.ingest; + +import org.elasticsearch.action.admin.indices.sampling.SamplingConfiguration; +import org.elasticsearch.action.admin.indices.sampling.SamplingMetadata; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.cluster.project.TestProjectResolvers; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.script.MockScriptEngine; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptModule; +import org.elasticsearch.script.ScriptService; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.ESTestCase; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; + +public class SamplingServiceTests extends ESTestCase { + + private static final String TEST_CONDITIONAL_SCRIPT = "ctx?.foo == 'bar'"; + + public void testMaybeSample() { + SamplingService samplingService = getTestSamplingService(); + + // First, test with a project that has no sampling config: + String indexName = randomIdentifier(); + ProjectMetadata.Builder projectBuilder = ProjectMetadata.builder(ProjectId.DEFAULT); + final ProjectId projectId = projectBuilder.getId(); + ProjectMetadata projectMetadata = projectBuilder.build(); + Map inputRawDocSource = randomMap(1, 100, () -> Tuple.tuple(randomAlphaOfLength(10), randomAlphaOfLength(10))); + final IndexRequest indexRequest = new IndexRequest(indexName).id("_id").source(inputRawDocSource); + samplingService.maybeSample(projectMetadata, indexRequest); + assertThat(samplingService.getLocalSample(projectId, indexName), empty()); + + // Now test with a valid configuration: + int maxSize = 100; + projectBuilder = ProjectMetadata.builder(projectMetadata) + .putCustom( + SamplingMetadata.TYPE, + new SamplingMetadata( + Map.of(indexName, new SamplingConfiguration(1.0, maxSize, ByteSizeValue.ofMb(100), TimeValue.timeValueDays(3), null)) + ) + ); + projectMetadata = projectBuilder.build(); + int docsToSample = randomIntBetween(1, maxSize); + for (int i = 0; i < docsToSample; i++) { + samplingService.maybeSample(projectMetadata, indexRequest); + } + List sample = samplingService.getLocalSample(projectId, indexName); + assertThat(sample, not(empty())); + // Since our sampling rate was 100%, we expect every document to have been sampled: + assertThat(sample.size(), equalTo(docsToSample)); + SamplingService.RawDocument rawDocument = sample.getFirst(); + assertThat(rawDocument.indexName(), equalTo(indexName)); + Map outputRawDocSource = XContentHelper.convertToMap( + rawDocument.contentType().xContent(), + rawDocument.source(), + 0, + rawDocument.source().length, + randomBoolean() + ); + assertThat(outputRawDocSource, equalTo(inputRawDocSource)); + + SamplingService.SampleStats stats = samplingService.getLocalSampleStats(projectId, indexName); + assertThat(stats.getSamples(), equalTo((long) docsToSample)); + assertThat(stats.getPotentialSamples(), equalTo((long) docsToSample)); + assertThat(stats.getSamplesRejectedForRate(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForCondition(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForException(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForMaxSamplesExceeded(), equalTo(0L)); + assertThat(stats.getLastException(), nullValue()); + assertThat(stats.getTimeSampling(), greaterThan(TimeValue.ZERO)); + assertThat(stats.getTimeCompilingCondition(), equalTo(TimeValue.ZERO)); + assertThat(stats.getTimeEvaluatingCondition(), equalTo(TimeValue.ZERO)); + } + + public void testMaybeSampleWithCondition() { + SamplingService samplingService = getTestSamplingService(); + String indexName = randomIdentifier(); + ProjectMetadata.Builder projectBuilder = ProjectMetadata.builder(ProjectId.DEFAULT) + .putCustom( + SamplingMetadata.TYPE, + new SamplingMetadata( + Map.of( + indexName, + new SamplingConfiguration(1.0, 100, ByteSizeValue.ofMb(100), TimeValue.timeValueDays(3), TEST_CONDITIONAL_SCRIPT) + ) + ) + ); + final ProjectId projectId = projectBuilder.getId(); + ProjectMetadata projectMetadata = projectBuilder.build(); + Map indexRequest1Source = Map.of("foo", "bar", "baz", "bop"); + final IndexRequest indexRequest1 = new IndexRequest(indexName).id("_id").source(indexRequest1Source); + samplingService.maybeSample(projectMetadata, indexRequest1); + final IndexRequest indexRequest2 = new IndexRequest(indexName).id("_id").source(Map.of("bar", "foo", "baz", "bop")); + samplingService.maybeSample(projectMetadata, indexRequest2); + List sample = samplingService.getLocalSample(projectId, indexName); + assertThat(sample.size(), equalTo(1)); + SamplingService.RawDocument rawDocument = sample.getFirst(); + assertThat(rawDocument.indexName(), equalTo(indexName)); + Map outputRawDocSource = XContentHelper.convertToMap( + rawDocument.contentType().xContent(), + rawDocument.source(), + 0, + rawDocument.source().length, + randomBoolean() + ); + assertThat(outputRawDocSource, equalTo(indexRequest1Source)); + + SamplingService.SampleStats stats = samplingService.getLocalSampleStats(projectId, indexName); + assertThat(stats.getSamples(), equalTo((long) 1)); + assertThat(stats.getPotentialSamples(), equalTo((long) 2)); + assertThat(stats.getSamplesRejectedForRate(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForCondition(), equalTo(1L)); + assertThat(stats.getSamplesRejectedForException(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForMaxSamplesExceeded(), equalTo(0L)); + assertThat(stats.getLastException(), nullValue()); + assertThat(stats.getTimeSampling(), greaterThan(TimeValue.ZERO)); + assertThat(stats.getTimeCompilingCondition(), greaterThan(TimeValue.ZERO)); + assertThat(stats.getTimeEvaluatingCondition(), greaterThan(TimeValue.ZERO)); + } + + public void testMaybeSampleWithLowRate() { + SamplingService samplingService = getTestSamplingService(); + String indexName = randomIdentifier(); + ProjectMetadata.Builder projectBuilder = ProjectMetadata.builder(ProjectId.DEFAULT) + .putCustom( + SamplingMetadata.TYPE, + new SamplingMetadata( + Map.of(indexName, new SamplingConfiguration(0.001, 100, ByteSizeValue.ofMb(100), TimeValue.timeValueDays(3), null)) + ) + ); + final ProjectId projectId = projectBuilder.getId(); + ProjectMetadata projectMetadata = projectBuilder.build(); + Map inputRawDocSource = randomMap(1, 100, () -> Tuple.tuple(randomAlphaOfLength(10), randomAlphaOfLength(10))); + final IndexRequest indexRequest = new IndexRequest(indexName).id("_id").source(inputRawDocSource); + for (int i = 0; i < 100; i++) { + samplingService.maybeSample(projectMetadata, indexRequest); + } + /* + * We had 100 chances to take a sample. We're sampling at a rate of one in 1000. The odds of even one are fairly low. The odds of + * 10 are so low that we will almost certainly never see that unless there is a bug. We're really just making sure that we don't + * introduce a bug where we ignore the rate. + */ + int samples = samplingService.getLocalSample(projectId, indexName).size(); + assertThat(samples, lessThan(10)); + + SamplingService.SampleStats stats = samplingService.getLocalSampleStats(projectId, indexName); + assertThat(stats.getSamples(), equalTo((long) samples)); + assertThat(stats.getPotentialSamples(), equalTo(100L)); + assertThat(stats.getSamplesRejectedForRate(), equalTo((long) 100 - samples)); + assertThat(stats.getSamplesRejectedForCondition(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForException(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForMaxSamplesExceeded(), equalTo(0L)); + assertThat(stats.getLastException(), nullValue()); + assertThat(stats.getTimeSampling(), greaterThan(TimeValue.ZERO)); + assertThat(stats.getTimeCompilingCondition(), equalTo(TimeValue.ZERO)); + assertThat(stats.getTimeEvaluatingCondition(), equalTo(TimeValue.ZERO)); + } + + public void testMaybeSampleMaxSamples() { + SamplingService samplingService = getTestSamplingService(); + String indexName = randomIdentifier(); + int maxSamples = randomIntBetween(1, 1000); + ProjectMetadata.Builder projectBuilder = ProjectMetadata.builder(ProjectId.DEFAULT) + .putCustom( + SamplingMetadata.TYPE, + new SamplingMetadata( + Map.of(indexName, new SamplingConfiguration(1.0, maxSamples, ByteSizeValue.ofMb(100), TimeValue.timeValueDays(3), null)) + ) + ); + final ProjectId projectId = projectBuilder.getId(); + ProjectMetadata projectMetadata = projectBuilder.build(); + Map inputRawDocSource = randomMap(1, 100, () -> Tuple.tuple(randomAlphaOfLength(10), randomAlphaOfLength(10))); + final IndexRequest indexRequest = new IndexRequest(indexName).id("_id").source(inputRawDocSource); + int docsToSample = randomIntBetween(maxSamples + 1, maxSamples + 1000); + for (int i = 0; i < docsToSample; i++) { + samplingService.maybeSample(projectMetadata, indexRequest); + } + assertThat(samplingService.getLocalSample(projectId, indexName).size(), equalTo(maxSamples)); + + SamplingService.SampleStats stats = samplingService.getLocalSampleStats(projectId, indexName); + assertThat(stats.getSamples(), equalTo((long) maxSamples)); + assertThat(stats.getPotentialSamples(), equalTo((long) docsToSample)); + assertThat(stats.getSamplesRejectedForRate(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForCondition(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForException(), equalTo(0L)); + assertThat(stats.getSamplesRejectedForMaxSamplesExceeded(), equalTo((long) docsToSample - maxSamples)); + assertThat(stats.getLastException(), nullValue()); + assertThat(stats.getTimeSampling(), greaterThan(TimeValue.ZERO)); + assertThat(stats.getTimeCompilingCondition(), equalTo(TimeValue.ZERO)); + assertThat(stats.getTimeEvaluatingCondition(), equalTo(TimeValue.ZERO)); + } + + private SamplingService getTestSamplingService() { + final ScriptService scriptService = new ScriptService( + Settings.EMPTY, + Map.of(Script.DEFAULT_SCRIPT_LANG, new MockScriptEngine(Script.DEFAULT_SCRIPT_LANG, Map.of(TEST_CONDITIONAL_SCRIPT, ctx -> { + Object fooVal = ctx.get("foo"); + return fooVal != null && fooVal.equals("bar"); + }), Map.of())), + new HashMap<>(ScriptModule.CORE_CONTEXTS), + () -> 1L, + TestProjectResolvers.singleProject(randomProjectIdOrDefault()) + ); + ClusterService clusterService = ClusterServiceUtils.createClusterService(new DeterministicTaskQueue().getThreadPool()); + final ProjectId projectId = ProjectId.DEFAULT; + final ProjectResolver projectResolver = TestProjectResolvers.singleProject(projectId); + return new SamplingService(scriptService, clusterService, projectResolver, System::currentTimeMillis); + } +}