Skip to content

Commit 8aba64f

Browse files
committed
Remove channels from ToAggregator method
1 parent c3baa17 commit 8aba64f

File tree

20 files changed

+66
-84
lines changed

20 files changed

+66
-84
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import java.util.List;
2020

2121
public class CountAggregatorFunction implements AggregatorFunction {
22-
public static AggregatorFunctionSupplier supplier(List<Integer> channels) {
22+
public static AggregatorFunctionSupplier supplier() {
2323
return new AggregatorFunctionSupplier() {
2424
@Override
2525
public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ public DataType dataType() {
126126
}
127127

128128
@Override
129-
public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
130-
return CountAggregatorFunction.supplier(inputChannels);
129+
public AggregatorFunctionSupplier supplier() {
130+
return CountAggregatorFunction.supplier();
131131
}
132132

133133
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ protected TypeResolution resolveType() {
209209
}
210210

211211
@Override
212-
public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
212+
public AggregatorFunctionSupplier supplier() {
213213
DataType type = field().dataType();
214214
int precision = this.precision == null
215215
? DEFAULT_PRECISION

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FromPartial.java

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,8 @@ public FromPartial withFilter(Expression filter) {
111111
}
112112

113113
@Override
114-
public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
115-
final ToAggregator toAggregator = (ToAggregator) function;
116-
if (inputChannels.size() != 1) {
117-
assert false : "from_partial aggregation requires exactly one input channel; got " + inputChannels;
118-
throw new IllegalArgumentException("from_partial aggregation requires exactly one input channel; got " + inputChannels);
119-
}
120-
final int inputChannel = inputChannels.get(0);
114+
public AggregatorFunctionSupplier supplier() {
115+
final AggregatorFunctionSupplier supplier = ((ToAggregator) function).supplier();
121116
return new AggregatorFunctionSupplier() {
122117
@Override
123118
public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() {
@@ -143,17 +138,17 @@ public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext
143138

144139
@Override
145140
public Aggregator.Factory aggregatorFactory(AggregatorMode mode, List<Integer> channels) {
146-
final AggregatorFunctionSupplier supplier;
147-
// TODO: Improve this code; We don't need to create an aggregator now
148-
try (var dummy = toAggregator.supplier(inputChannels).aggregator(DriverContext.getLocalDriver(), channels)) {
149-
var intermediateChannels = IntStream.range(0, dummy.intermediateBlockCount()).boxed().toList();
150-
supplier = toAggregator.supplier(intermediateChannels);
141+
if (channels.size() != 1) {
142+
assert false : "from_partial aggregation requires exactly one input channel; got " + channels;
143+
throw new IllegalArgumentException("from_partial aggregation requires exactly one input channel; got " + channels);
151144
}
145+
final int inputChannel = channels.get(0);
146+
var intermediateChannels = IntStream.range(0, supplier.nonGroupingIntermediateStateDesc().size()).boxed().toList();
152147
return new Aggregator.Factory() {
153148
@Override
154149
public Aggregator apply(DriverContext driverContext) {
155150
// use groupingAggregator since we can receive intermediate output from a grouping aggregate
156-
final var groupingAggregator = supplier.groupingAggregator(driverContext, channels);
151+
final var groupingAggregator = supplier.groupingAggregator(driverContext, intermediateChannels);
157152
return new Aggregator(new FromPartialAggregatorFunction(driverContext, groupingAggregator, inputChannel), mode);
158153
}
159154

@@ -166,15 +161,16 @@ public String describe() {
166161

167162
@Override
168163
public GroupingAggregator.Factory groupingAggregatorFactory(AggregatorMode mode, List<Integer> channels) {
169-
final AggregatorFunctionSupplier supplier;
170-
try (var dummy = toAggregator.supplier(inputChannels).aggregator(DriverContext.getLocalDriver(), channels)) {
171-
var intermediateChannels = IntStream.range(0, dummy.intermediateBlockCount()).boxed().toList();
172-
supplier = toAggregator.supplier(intermediateChannels);
164+
if (channels.size() != 1) {
165+
assert false : "from_partial aggregation requires exactly one input channel; got " + channels;
166+
throw new IllegalArgumentException("from_partial aggregation requires exactly one input channel; got " + channels);
173167
}
168+
final int inputChannel = channels.get(0);
169+
var intermediateChannels = IntStream.range(0, supplier.nonGroupingIntermediateStateDesc().size()).boxed().toList();
174170
return new GroupingAggregator.Factory() {
175171
@Override
176172
public GroupingAggregator apply(DriverContext driverContext) {
177-
final GroupingAggregatorFunction aggregator = supplier.groupingAggregator(driverContext, channels);
173+
final GroupingAggregatorFunction aggregator = supplier.groupingAggregator(driverContext, intermediateChannels);
178174
return new GroupingAggregator(new FromPartialGroupingAggregatorFunction(aggregator, inputChannel), mode);
179175
}
180176

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ public DataType dataType() {
141141
}
142142

143143
@Override
144-
public final AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
144+
public final AggregatorFunctionSupplier supplier() {
145145
DataType type = field().dataType();
146146
if (SUPPLIERS.containsKey(type) == false) {
147147
// If the type checking did its job, this should never happen

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,17 @@ public MedianAbsoluteDeviation withFilter(Expression filter) {
9999
}
100100

101101
@Override
102-
protected AggregatorFunctionSupplier longSupplier(List<Integer> inputChannels) {
102+
protected AggregatorFunctionSupplier longSupplier() {
103103
return new MedianAbsoluteDeviationLongAggregatorFunctionSupplier();
104104
}
105105

106106
@Override
107-
protected AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels) {
107+
protected AggregatorFunctionSupplier intSupplier() {
108108
return new MedianAbsoluteDeviationIntAggregatorFunctionSupplier();
109109
}
110110

111111
@Override
112-
protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
112+
protected AggregatorFunctionSupplier doubleSupplier() {
113113
return new MedianAbsoluteDeviationDoubleAggregatorFunctionSupplier();
114114
}
115115

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ public DataType dataType() {
141141
}
142142

143143
@Override
144-
public final AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
144+
public final AggregatorFunctionSupplier supplier() {
145145
DataType type = field().dataType();
146146
if (SUPPLIERS.containsKey(type) == false) {
147147
// If the type checking did its job, this should never happen

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,26 @@ public DataType dataType() {
9292
}
9393

9494
@Override
95-
public final AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
95+
public final AggregatorFunctionSupplier supplier() {
9696
DataType type = field().dataType();
9797
if (supportsDates() && type == DataType.DATETIME) {
98-
return longSupplier(inputChannels);
98+
return longSupplier();
9999
}
100100
if (type == DataType.LONG) {
101-
return longSupplier(inputChannels);
101+
return longSupplier();
102102
}
103103
if (type == DataType.INTEGER) {
104-
return intSupplier(inputChannels);
104+
return intSupplier();
105105
}
106106
if (type == DataType.DOUBLE) {
107-
return doubleSupplier(inputChannels);
107+
return doubleSupplier();
108108
}
109109
throw EsqlIllegalArgumentException.illegalDataType(type);
110110
}
111111

112-
protected abstract AggregatorFunctionSupplier longSupplier(List<Integer> inputChannels);
112+
protected abstract AggregatorFunctionSupplier longSupplier();
113113

114-
protected abstract AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels);
114+
protected abstract AggregatorFunctionSupplier intSupplier();
115115

116-
protected abstract AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels);
116+
protected abstract AggregatorFunctionSupplier doubleSupplier();
117117
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,17 +156,17 @@ protected TypeResolution resolveType() {
156156
}
157157

158158
@Override
159-
protected AggregatorFunctionSupplier longSupplier(List<Integer> inputChannels) {
159+
protected AggregatorFunctionSupplier longSupplier() {
160160
return new PercentileLongAggregatorFunctionSupplier(percentileValue());
161161
}
162162

163163
@Override
164-
protected AggregatorFunctionSupplier intSupplier(List<Integer> inputChannels) {
164+
protected AggregatorFunctionSupplier intSupplier() {
165165
return new PercentileIntAggregatorFunctionSupplier(percentileValue());
166166
}
167167

168168
@Override
169-
protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
169+
protected AggregatorFunctionSupplier doubleSupplier() {
170170
return new PercentileDoubleAggregatorFunctionSupplier(percentileValue());
171171
}
172172

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,7 @@ long unitInMillis() {
168168
}
169169

170170
@Override
171-
public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
172-
if (inputChannels.size() != 2 && inputChannels.size() != 3) {
173-
throw new IllegalArgumentException("rate requires two for raw input or three channels for partial input; got " + inputChannels);
174-
}
171+
public AggregatorFunctionSupplier supplier() {
175172
final long unitInMillis = unitInMillis();
176173
final DataType type = field().dataType();
177174
return switch (type) {

0 commit comments

Comments
 (0)