Skip to content

Commit 3e10648

Browse files
committed
Refactor driver factory
1 parent e05304b commit 3e10648

File tree

9 files changed

+151
-145
lines changed

9 files changed

+151
-145
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.operator;
9+
10+
import org.elasticsearch.compute.Describable;
11+
12+
public record DriverFactory(String sessionId, DriverSupplier driverSupplier, DriverParallelism driverParallelism) implements Describable {
13+
14+
@Override
15+
public String describe() {
16+
return "DriverFactory(instances = "
17+
+ driverParallelism.instanceCount()
18+
+ ", type = "
19+
+ driverParallelism.type()
20+
+ ")\n"
21+
+ driverSupplier.describe();
22+
}
23+
24+
public Driver createDriver() {
25+
return driverSupplier.create(sessionId);
26+
}
27+
28+
public interface DriverSupplier extends Describable {
29+
Driver create(String sessionId);
30+
}
31+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.operator;
9+
10+
/**
11+
* The count and type of driver parallelism.
12+
*/
13+
public record DriverParallelism(Type type, int instanceCount) {
14+
15+
public DriverParallelism {
16+
if (instanceCount <= 0) {
17+
throw new IllegalArgumentException("instance count must be greater than zero; got: " + instanceCount);
18+
}
19+
}
20+
21+
public static final DriverParallelism SINGLE = new DriverParallelism(Type.SINGLETON, 1);
22+
23+
public enum Type {
24+
SINGLETON,
25+
DATA_PARALLELISM,
26+
TASK_LEVEL_PARALLELISM
27+
}
28+
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverTaskRunner.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.action.support.ChannelActionListener;
1515
import org.elasticsearch.common.io.stream.StreamInput;
1616
import org.elasticsearch.common.io.stream.StreamOutput;
17+
import org.elasticsearch.core.Releasables;
1718
import org.elasticsearch.tasks.CancellableTask;
1819
import org.elasticsearch.tasks.Task;
1920
import org.elasticsearch.tasks.TaskId;
@@ -25,6 +26,7 @@
2526
import org.elasticsearch.transport.TransportService;
2627

2728
import java.io.IOException;
29+
import java.util.ArrayList;
2830
import java.util.List;
2931
import java.util.Map;
3032
import java.util.Objects;
@@ -42,7 +44,27 @@ public DriverTaskRunner(TransportService transportService, Executor executor) {
4244
transportService.registerRequestHandler(ACTION_NAME, executor, DriverRequest::new, new DriverRequestHandler(transportService));
4345
}
4446

45-
public void executeDrivers(Task parentTask, List<Driver> drivers, Executor executor, ActionListener<Void> listener) {
47+
public void executeDrivers(
48+
Task parentTask,
49+
List<DriverFactory> driverFactories,
50+
boolean captureProfiles,
51+
Executor executor,
52+
ActionListener<List<DriverProfile>> listener
53+
) {
54+
final List<Driver> drivers = new ArrayList<>();
55+
boolean success = false;
56+
try {
57+
for (DriverFactory df : driverFactories) {
58+
for (int i = 0; i < df.driverParallelism().instanceCount(); i++) {
59+
drivers.add(df.createDriver());
60+
}
61+
}
62+
success = true;
63+
} finally {
64+
if (success == false) {
65+
Releasables.close(Releasables.wrap(drivers));
66+
}
67+
}
4668
var runner = new DriverRunner(transportService.getThreadPool().getThreadContext()) {
4769
@Override
4870
protected void start(Driver driver, ActionListener<Void> driverListener) {
@@ -62,7 +84,13 @@ protected void start(Driver driver, ActionListener<Void> driverListener) {
6284
);
6385
}
6486
};
65-
runner.runToCompletion(drivers, listener);
87+
runner.runToCompletion(drivers, listener.map(unused -> {
88+
if (captureProfiles) {
89+
return drivers.stream().map(Driver::profile).toList();
90+
} else {
91+
return List.of();
92+
}
93+
}));
6694
}
6795

6896
private static class DriverRequest extends ActionRequest implements CompositeIndicesRequest {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.compute.lucene.LuceneTopNSourceOperator;
2424
import org.elasticsearch.compute.lucene.TimeSeriesSortedSourceOperatorFactory;
2525
import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
26+
import org.elasticsearch.compute.operator.DriverParallelism;
2627
import org.elasticsearch.compute.operator.Operator;
2728
import org.elasticsearch.compute.operator.OrdinalsGroupingOperator;
2829
import org.elasticsearch.compute.operator.SourceOperator;
@@ -59,7 +60,6 @@
5960
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
6061
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.Sort;
6162
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
62-
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.DriverParallelism;
6363
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
6464
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation;
6565

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

Lines changed: 18 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.elasticsearch.compute.operator.ColumnLoadOperator;
2424
import org.elasticsearch.compute.operator.Driver;
2525
import org.elasticsearch.compute.operator.DriverContext;
26+
import org.elasticsearch.compute.operator.DriverFactory;
27+
import org.elasticsearch.compute.operator.DriverParallelism;
2628
import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory;
2729
import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory;
2830
import org.elasticsearch.compute.operator.LimitOperator;
@@ -175,9 +177,8 @@ public LocalExecutionPlanner(
175177
/**
176178
* turn the given plan into a list of drivers to execute
177179
*/
178-
public LocalExecutionPlan plan(String taskDescription, FoldContext foldCtx, PhysicalPlan localPhysicalPlan) {
180+
public DriverFactory plan(String taskDescription, FoldContext foldCtx, PhysicalPlan localPhysicalPlan) {
179181
var context = new LocalExecutionPlannerContext(
180-
new ArrayList<>(),
181182
new Holder<>(DriverParallelism.SINGLE),
182183
configuration.pragmas(),
183184
bigArrays,
@@ -194,23 +195,20 @@ public LocalExecutionPlan plan(String taskDescription, FoldContext foldCtx, Phys
194195
PhysicalOperation physicalOperation = plan(localPhysicalPlan, context);
195196

196197
final TimeValue statusInterval = configuration.pragmas().statusInterval();
197-
context.addDriverFactory(
198-
new DriverFactory(
199-
new DriverSupplier(
200-
taskDescription,
201-
ClusterName.CLUSTER_NAME_SETTING.get(settings).value(),
202-
Node.NODE_NAME_SETTING.get(settings),
203-
context.bigArrays,
204-
context.blockFactory,
205-
physicalOperation,
206-
statusInterval,
207-
settings
208-
),
209-
context.driverParallelism().get()
210-
)
198+
return new DriverFactory(
199+
sessionId,
200+
new DriverSupplier(
201+
taskDescription,
202+
ClusterName.CLUSTER_NAME_SETTING.get(settings).value(),
203+
Node.NODE_NAME_SETTING.get(settings),
204+
context.bigArrays,
205+
context.blockFactory,
206+
physicalOperation,
207+
statusInterval,
208+
settings
209+
),
210+
context.driverParallelism().get()
211211
);
212-
213-
return new LocalExecutionPlan(context.driverFactories);
214212
}
215213

216214
private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext context) {
@@ -815,42 +813,18 @@ public String toString() {
815813
}
816814
}
817815

818-
/**
819-
* The count and type of driver parallelism.
820-
*/
821-
record DriverParallelism(Type type, int instanceCount) {
822-
823-
DriverParallelism {
824-
if (instanceCount <= 0) {
825-
throw new IllegalArgumentException("instance count must be greater than zero; got: " + instanceCount);
826-
}
827-
}
828-
829-
static final DriverParallelism SINGLE = new DriverParallelism(Type.SINGLETON, 1);
830-
831-
enum Type {
832-
SINGLETON,
833-
DATA_PARALLELISM,
834-
TASK_LEVEL_PARALLELISM
835-
}
836-
}
837-
838816
/**
839817
* Context object used while generating a local plan. Currently only collects the driver factories as well as
840818
* maintains information how many driver instances should be created for a given driver.
841819
*/
842820
public record LocalExecutionPlannerContext(
843-
List<DriverFactory> driverFactories,
844821
Holder<DriverParallelism> driverParallelism,
845822
QueryPragmas queryPragmas,
846823
BigArrays bigArrays,
847824
BlockFactory blockFactory,
848825
FoldContext foldCtx,
849826
Settings settings
850827
) {
851-
void addDriverFactory(DriverFactory driverFactory) {
852-
driverFactories.add(driverFactory);
853-
}
854828

855829
void driverParallelism(DriverParallelism parallelism) {
856830
driverParallelism.set(parallelism);
@@ -879,9 +853,9 @@ record DriverSupplier(
879853
PhysicalOperation physicalOperation,
880854
TimeValue statusInterval,
881855
Settings settings
882-
) implements Function<String, Driver>, Describable {
856+
) implements DriverFactory.DriverSupplier {
883857
@Override
884-
public Driver apply(String sessionId) {
858+
public Driver create(String sessionId) {
885859
SourceOperator source = null;
886860
List<Operator> operators = new ArrayList<>();
887861
SinkOperator sink = null;
@@ -925,51 +899,4 @@ public String describe() {
925899
return physicalOperation.describe();
926900
}
927901
}
928-
929-
record DriverFactory(DriverSupplier driverSupplier, DriverParallelism driverParallelism) implements Describable {
930-
@Override
931-
public String describe() {
932-
return "DriverFactory(instances = "
933-
+ driverParallelism.instanceCount()
934-
+ ", type = "
935-
+ driverParallelism.type()
936-
+ ")\n"
937-
+ driverSupplier.describe();
938-
}
939-
}
940-
941-
/**
942-
* Plan representation that is geared towards execution on a single node
943-
*/
944-
public static class LocalExecutionPlan implements Describable {
945-
final List<DriverFactory> driverFactories;
946-
947-
LocalExecutionPlan(List<DriverFactory> driverFactories) {
948-
this.driverFactories = driverFactories;
949-
}
950-
951-
public List<Driver> createDrivers(String sessionId) {
952-
List<Driver> drivers = new ArrayList<>();
953-
boolean success = false;
954-
try {
955-
for (DriverFactory df : driverFactories) {
956-
for (int i = 0; i < df.driverParallelism.instanceCount; i++) {
957-
logger.trace("building {} {}", i, df);
958-
drivers.add(df.driverSupplier.apply(sessionId));
959-
}
960-
}
961-
success = true;
962-
return drivers;
963-
} finally {
964-
if (success == false) {
965-
Releasables.close(Releasables.wrap(drivers));
966-
}
967-
}
968-
}
969-
970-
@Override
971-
public String describe() {
972-
return driverFactories.stream().map(DriverFactory::describe).collect(joining("\n"));
973-
}
974-
}
975902
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import org.elasticsearch.common.util.concurrent.RunOnce;
1717
import org.elasticsearch.compute.data.BlockFactory;
1818
import org.elasticsearch.compute.data.Page;
19-
import org.elasticsearch.compute.operator.Driver;
19+
import org.elasticsearch.compute.operator.DriverFactory;
2020
import org.elasticsearch.compute.operator.DriverProfile;
2121
import org.elasticsearch.compute.operator.DriverTaskRunner;
2222
import org.elasticsearch.compute.operator.exchange.ExchangeService;
@@ -356,7 +356,7 @@ public SourceProvider createSourceProvider() {
356356
new EsPhysicalOperationProviders.DefaultShardContext(i, searchExecutionContext, searchContext.request().getAliasFilter())
357357
);
358358
}
359-
final List<Driver> drivers;
359+
final DriverFactory driverFactory;
360360
try {
361361
LocalExecutionPlanner planner = new LocalExecutionPlanner(
362362
context.sessionId(),
@@ -373,39 +373,22 @@ public SourceProvider createSourceProvider() {
373373
new EsPhysicalOperationProviders(context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis()),
374374
contexts
375375
);
376-
377376
LOGGER.debug("Received physical plan:\n{}", plan);
378-
379377
plan = PlannerUtils.localPlan(context.searchExecutionContexts(), context.configuration(), context.foldCtx(), plan);
380-
// the planner will also set the driver parallelism in LocalExecutionPlanner.LocalExecutionPlan (used down below)
381-
// it's doing this in the planning of EsQueryExec (the source of the data)
382-
// see also EsPhysicalOperationProviders.sourcePhysicalOperation
383-
LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = planner.plan(context.taskDescription(), context.foldCtx(), plan);
378+
driverFactory = planner.plan(context.taskDescription(), context.foldCtx(), plan);
384379
if (LOGGER.isDebugEnabled()) {
385-
LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe());
386-
}
387-
drivers = localExecutionPlan.createDrivers(context.sessionId());
388-
if (drivers.isEmpty()) {
389-
throw new IllegalStateException("no drivers created");
380+
LOGGER.debug("driver factory:\n{}", driverFactory.describe());
390381
}
391-
LOGGER.debug("using {} drivers", drivers.size());
392382
} catch (Exception e) {
393383
listener.onFailure(e);
394384
return;
395385
}
396-
ActionListener<Void> listenerCollectingStatus = listener.map(ignored -> {
397-
if (context.configuration().profile()) {
398-
return drivers.stream().map(Driver::profile).toList();
399-
} else {
400-
return List.of();
401-
}
402-
});
403-
listenerCollectingStatus = ActionListener.releaseAfter(listenerCollectingStatus, () -> Releasables.close(drivers));
404386
driverRunner.executeDrivers(
405387
task,
406-
drivers,
388+
List.of(driverFactory),
389+
context.configuration().profile(),
407390
transportService.getThreadPool().executor(ESQL_WORKER_THREAD_POOL_NAME),
408-
listenerCollectingStatus
391+
listener
409392
);
410393
}
411394

0 commit comments

Comments
 (0)