
Commit a38b04c

Execute single shard each pipeline for rate
1 parent 593212b commit a38b04c

File tree

3 files changed: +161 -38 lines changed


x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java

Lines changed: 110 additions & 25 deletions
@@ -7,15 +7,20 @@
 
 package org.elasticsearch.xpack.esql.action;
 
+import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.iterable.Iterables;
+import org.elasticsearch.compute.lucene.LuceneSourceOperator;
 import org.elasticsearch.compute.lucene.TimeSeriesSourceOperator;
 import org.elasticsearch.compute.operator.DriverProfile;
 import org.elasticsearch.compute.operator.OperatorStatus;
 import org.elasticsearch.compute.operator.TimeSeriesAggregationOperator;
 import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
+import org.hamcrest.Matchers;
 import org.junit.Before;
 
 import java.util.ArrayList;
@@ -481,33 +486,113 @@ public void testFieldDoesNotExist() {
         }
     }
 
-    public void testRateProfile() {
-        EsqlQueryRequest request = new EsqlQueryRequest();
-        request.profile(true);
-        request.query("TS hosts | STATS sum(rate(request_count)) BY cluster, bucket(@timestamp, 1minute) | SORT cluster");
-        try (var resp = run(request)) {
-            EsqlQueryResponse.Profile profile = resp.profile();
-            List<DriverProfile> dataProfiles = profile.drivers().stream().filter(d -> d.description().equals("data")).toList();
-            int totalTimeSeries = 0;
-            for (DriverProfile p : dataProfiles) {
-                if (p.operators().stream().anyMatch(s -> s.status() instanceof TimeSeriesSourceOperator.Status)) {
-                    totalTimeSeries++;
-                    assertThat(p.operators(), hasSize(2));
-                    assertThat(p.operators().get(1).operator(), equalTo("ExchangeSinkOperator"));
-                } else if (p.operators().stream().anyMatch(s -> s.status() instanceof TimeSeriesAggregationOperator.Status)) {
-                    assertThat(p.operators(), hasSize(3));
-                    assertThat(p.operators().get(0).operator(), equalTo("ExchangeSourceOperator"));
-                    assertThat(p.operators().get(1).operator(), containsString("TimeSeriesAggregationOperator"));
-                    assertThat(p.operators().get(2).operator(), equalTo("ExchangeSinkOperator"));
-                } else {
-                    assertThat(p.operators(), hasSize(4));
-                    assertThat(p.operators().get(0).operator(), equalTo("ExchangeSourceOperator"));
-                    assertThat(p.operators().get(1).operator(), containsString("TimeSeriesExtractFieldOperator"));
-                    assertThat(p.operators().get(2).operator(), containsString("EvalOperator"));
-                    assertThat(p.operators().get(3).operator(), equalTo("ExchangeSinkOperator"));
+    public void testProfile() {
+        String dataNode = Iterables.get(clusterService().state().getNodes().getDataNodes().keySet(), 0);
+        Settings indexSettings = Settings.builder()
+            .put("mode", "time_series")
+            .putList("routing_path", List.of("host", "cluster"))
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 3)
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .put("index.routing.allocation.require._id", dataNode)
+            .build();
+        String index = "my-hosts";
+        client().admin()
+            .indices()
+            .prepareCreate(index)
+            .setSettings(indexSettings)
+            .setMapping(
+                "@timestamp",
+                "type=date",
+                "host",
+                "type=keyword,time_series_dimension=true",
+                "cluster",
+                "type=keyword,time_series_dimension=true",
+                "memory",
+                "type=long,time_series_metric=gauge",
+                "request_count",
+                "type=integer,time_series_metric=counter"
+            )
+            .get();
+        Randomness.shuffle(docs);
+        for (Doc doc : docs) {
+            client().prepareIndex(index)
+                .setSource(
+                    "@timestamp",
+                    doc.timestamp,
+                    "host",
+                    doc.host,
+                    "cluster",
+                    doc.cluster,
+                    "memory",
+                    doc.memory.getBytes(),
+                    "cpu",
+                    doc.cpu,
+                    "request_count",
+                    doc.requestCount
+                )
+                .get();
+        }
+        client().admin().indices().prepareRefresh(index).get();
+        QueryPragmas pragmas = new QueryPragmas(
+            Settings.builder()
+                .put(QueryPragmas.MAX_CONCURRENT_SHARDS_PER_NODE.getKey(), between(3, 10))
+                .put(QueryPragmas.TASK_CONCURRENCY.getKey(), 1)
+                .build()
+        );
+        // The rate aggregation is executed with one shard at a time
+        {
+            EsqlQueryRequest request = new EsqlQueryRequest();
+            request.profile(true);
+            request.pragmas(pragmas);
+            request.acceptedPragmaRisks(true);
+            request.query("TS my-hosts | STATS sum(rate(request_count)) BY cluster, bucket(@timestamp, 1minute) | SORT cluster");
+            try (var resp = run(request)) {
+                EsqlQueryResponse.Profile profile = resp.profile();
+                List<DriverProfile> dataProfiles = profile.drivers().stream().filter(d -> d.description().equals("data")).toList();
+                for (DriverProfile p : dataProfiles) {
+                    if (p.operators().stream().anyMatch(s -> s.status() instanceof TimeSeriesSourceOperator.Status)) {
+                        assertThat(p.operators(), hasSize(2));
+                        TimeSeriesSourceOperator.Status status = (TimeSeriesSourceOperator.Status) p.operators().get(0).status();
+                        assertThat(status.processedShards(), hasSize(1));
+                        assertThat(p.operators().get(1).operator(), equalTo("ExchangeSinkOperator"));
+                    } else if (p.operators().stream().anyMatch(s -> s.status() instanceof TimeSeriesAggregationOperator.Status)) {
+                        assertThat(p.operators(), hasSize(3));
+                        assertThat(p.operators().get(0).operator(), equalTo("ExchangeSourceOperator"));
+                        assertThat(p.operators().get(1).operator(), containsString("TimeSeriesAggregationOperator"));
+                        assertThat(p.operators().get(2).operator(), equalTo("ExchangeSinkOperator"));
+                    } else {
+                        assertThat(p.operators(), hasSize(4));
+                        assertThat(p.operators().get(0).operator(), equalTo("ExchangeSourceOperator"));
+                        assertThat(p.operators().get(1).operator(), containsString("TimeSeriesExtractFieldOperator"));
+                        assertThat(p.operators().get(2).operator(), containsString("EvalOperator"));
+                        assertThat(p.operators().get(3).operator(), equalTo("ExchangeSinkOperator"));
+                    }
                 }
+                assertThat(dataProfiles, hasSize(9));
+            }
+        }
+        // non-rate aggregation is executed with multiple shards at a time
+        {
+            EsqlQueryRequest request = new EsqlQueryRequest();
+            request.profile(true);
+            request.pragmas(pragmas);
+            request.acceptedPragmaRisks(true);
+            request.query("TS my-hosts | STATS avg(avg_over_time(cpu)) BY cluster, bucket(@timestamp, 1minute) | SORT cluster");
+            try (var resp = run(request)) {
+                EsqlQueryResponse.Profile profile = resp.profile();
+                List<DriverProfile> dataProfiles = profile.drivers().stream().filter(d -> d.description().equals("data")).toList();
+                assertThat(dataProfiles, hasSize(1));
+                List<OperatorStatus> ops = dataProfiles.get(0).operators();
+                assertThat(ops, hasSize(5));
+                assertThat(ops.get(0).operator(), containsString("LuceneSourceOperator"));
+                assertThat(ops.get(0).status(), Matchers.instanceOf(LuceneSourceOperator.Status.class));
+                LuceneSourceOperator.Status status = (LuceneSourceOperator.Status) ops.get(0).status();
+                assertThat(status.processedShards(), hasSize(3));
+                assertThat(ops.get(1).operator(), containsString("EvalOperator"));
+                assertThat(ops.get(2).operator(), containsString("ValuesSourceReaderOperator"));
+                assertThat(ops.get(3).operator(), containsString("TimeSeriesAggregationOperator"));
+                assertThat(ops.get(4).operator(), containsString("ExchangeSinkOperator"));
             }
-            assertThat(totalTimeSeries, equalTo(dataProfiles.size() / 3));
         }
     }
 
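A rough sanity check on the profile count asserted in the rate branch above, assuming the test setup (3 shards, TASK_CONCURRENCY = 1) and the commit's note that each time-series pipeline uses 3 drivers; the variable names here are illustrative, not part of the change:

// Illustrative arithmetic only: why the rate query expects nine "data" driver profiles.
int shards = 3;                  // SETTING_NUMBER_OF_SHARDS in the test index
int driversPerPipeline = 3;      // source, extract-field/eval, and aggregation drivers per single-shard pipeline
int expectedDataProfiles = shards * driversPerPipeline;  // 9, matching assertThat(dataProfiles, hasSize(9))
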
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java

Lines changed: 9 additions & 0 deletions
@@ -159,6 +159,15 @@ public static String[] planOriginalIndices(PhysicalPlan plan) {
         return indices.toArray(String[]::new);
     }
 
+    public static boolean requiresSortedTimeSeriesSource(PhysicalPlan plan) {
+        return plan.anyMatch(e -> {
+            if (e instanceof FragmentExec f) {
+                return f.fragment().anyMatch(l -> l instanceof EsRelation r && r.indexMode() == IndexMode.TIME_SERIES);
+            }
+            return false;
+        });
+    }
+
     private static void forEachRelation(PhysicalPlan plan, Consumer<EsRelation> action) {
         plan.forEachDown(FragmentExec.class, f -> f.fragment().forEachDown(EsRelation.class, r -> {
             if (r.indexMode() != IndexMode.LOOKUP) {

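For context, the new helper checks whether any plan fragment reads a relation in TIME_SERIES index mode; it is consumed in DataNodeComputeHandler below to decide whether data-node execution should fall back to single-shard pipelines. A minimal sketch of the call pattern, with the surrounding request/plan variables assumed rather than shown in this hunk:

// Sketch only: gate the single-shard pipeline mode on a sorted time-series source.
boolean sortedTimeSeriesSource = PlannerUtils.requiresSortedTimeSeriesSource(request.plan());
if (sortedTimeSeriesSource) {
    // run one shard per pipeline in this mode (see DataNodeComputeHandler below)
}
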
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java

Lines changed: 42 additions & 13 deletions
@@ -227,6 +227,7 @@ private class DataNodeRequestExecutor {
         private final ComputeListener computeListener;
         private final int maxConcurrentShards;
         private final ExchangeSink blockingSink; // block until we have completed on all shards or the coordinator has enough data
+        private final boolean singleShardPipeline;
         private final boolean failFastOnShardFailure;
         private final Map<ShardId, Exception> shardLevelFailures;
 
@@ -238,6 +239,7 @@ private class DataNodeRequestExecutor {
             int maxConcurrentShards,
             boolean failFastOnShardFailure,
             Map<ShardId, Exception> shardLevelFailures,
+            boolean singleShardPipeline,
             ComputeListener computeListener
         ) {
             this.flags = flags;
@@ -248,6 +250,7 @@ private class DataNodeRequestExecutor {
             this.maxConcurrentShards = maxConcurrentShards;
             this.failFastOnShardFailure = failFastOnShardFailure;
             this.shardLevelFailures = shardLevelFailures;
+            this.singleShardPipeline = singleShardPipeline;
             this.blockingSink = exchangeSink.createExchangeSink(() -> {});
         }
 
@@ -297,18 +300,37 @@ public void onFailure(Exception e) {
                     batchListener.onResponse(DriverCompletionInfo.EMPTY);
                     return;
                 }
-                var computeContext = new ComputeContext(
-                    sessionId,
-                    "data",
-                    clusterAlias,
-                    flags,
-                    searchContexts,
-                    configuration,
-                    configuration.newFoldContext(),
-                    null,
-                    () -> exchangeSink.createExchangeSink(pagesProduced::incrementAndGet)
-                );
-                computeService.runCompute(parentTask, computeContext, request.plan(), batchListener);
+                if (singleShardPipeline) {
+                    try (ComputeListener sub = new ComputeListener(threadPool, () -> {}, batchListener)) {
+                        for (SearchContext searchContext : searchContexts) {
+                            var computeContext = new ComputeContext(
+                                sessionId,
+                                "data",
+                                clusterAlias,
+                                flags,
+                                List.of(searchContext),
+                                configuration,
+                                configuration.newFoldContext(),
+                                null,
+                                () -> exchangeSink.createExchangeSink(pagesProduced::incrementAndGet)
+                            );
+                            computeService.runCompute(parentTask, computeContext, request.plan(), sub.acquireCompute());
+                        }
+                    }
+                } else {
+                    var computeContext = new ComputeContext(
+                        sessionId,
+                        "data",
+                        clusterAlias,
+                        flags,
+                        searchContexts,
+                        configuration,
+                        configuration.newFoldContext(),
+                        null,
+                        () -> exchangeSink.createExchangeSink(pagesProduced::incrementAndGet)
+                    );
+                    computeService.runCompute(parentTask, computeContext, request.plan(), batchListener);
+                }
             }, batchListener::onFailure));
         }
 
@@ -428,14 +450,21 @@ private void runComputeOnDataNode(
             exchangeService.finishSinkHandler(request.sessionId(), new TaskCancelledException(task.getReasonCancelled()));
         });
         EsqlFlags flags = computeService.createFlags();
+        int maxConcurrentShards = request.pragmas().maxConcurrentShardsPerNode();
+        final boolean sortedTimeSeriesSource = PlannerUtils.requiresSortedTimeSeriesSource(request.plan());
+        if (sortedTimeSeriesSource) {
+            // each time-series pipeline uses 3 drivers
+            maxConcurrentShards = Math.clamp(Math.ceilDiv(request.pragmas().taskConcurrency(), 3), 1, maxConcurrentShards);
+        }
         DataNodeRequestExecutor dataNodeRequestExecutor = new DataNodeRequestExecutor(
             flags,
             request,
             task,
             internalSink,
-            request.configuration().pragmas().maxConcurrentShardsPerNode(),
+            maxConcurrentShards,
             failFastOnShardFailure,
             shardLevelFailures,
+            sortedTimeSeriesSource,
             computeListener
         );
         dataNodeRequestExecutor.start();

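A quick worked example of the new clamp for sorted time-series sources, with illustrative pragma values rather than ones taken from the change: the test above sets TASK_CONCURRENCY to 1, so ceilDiv(1, 3) = 1 and at most one shard pipeline runs at a time; with a task concurrency of 12 and a per-node limit of 10, ceilDiv(12, 3) = 4, clamped to [1, 10], so four single-shard pipelines can run concurrently.

// Illustrative only: deriving maxConcurrentShards when the plan reads a sorted time-series source.
int taskConcurrency = 12;             // hypothetical pragma value
int maxConcurrentShardsPerNode = 10;  // hypothetical pragma value
int maxConcurrentShards = Math.clamp(Math.ceilDiv(taskConcurrency, 3), 1, maxConcurrentShardsPerNode);  // = 4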