diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..2f682c14f161e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunction.java @@ -0,0 +1,193 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link LossySumDoubleAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. 
+ */ +public final class LossySumDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("value", ElementType.DOUBLE), + new IntermediateStateDesc("unusedDeltas", ElementType.DOUBLE), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final DriverContext driverContext; + + private final LossySumDoubleAggregator.SumState state; + + private final List channels; + + public LossySumDoubleAggregatorFunction(DriverContext driverContext, List channels, + LossySumDoubleAggregator.SumState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static LossySumDoubleAggregatorFunction create(DriverContext driverContext, + List channels) { + return new LossySumDoubleAggregatorFunction(driverContext, channels, LossySumDoubleAggregator.initSingle()); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, mask); + return; + } + addRawVector(vVector, mask); + } + + private void addRawInputNotMasked(Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock); + return; + } + addRawVector(vVector); + } + + private void addRawVector(DoubleVector vVector) { + state.seen(true); + for (int valuesPosition = 0; valuesPosition < 
vVector.getPositionCount(); valuesPosition++) { + double vValue = vVector.getDouble(valuesPosition); + LossySumDoubleAggregator.combine(state, vValue); + } + } + + private void addRawVector(DoubleVector vVector, BooleanVector mask) { + state.seen(true); + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + double vValue = vVector.getDouble(valuesPosition); + LossySumDoubleAggregator.combine(state, vValue); + } + } + + private void addRawBlock(DoubleBlock vBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (vBlock.isNull(p)) { + continue; + } + state.seen(true); + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vBlock.getValueCount(p); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + LossySumDoubleAggregator.combine(state, vValue); + } + } + } + + private void addRawBlock(DoubleBlock vBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (vBlock.isNull(p)) { + continue; + } + state.seen(true); + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vBlock.getValueCount(p); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + LossySumDoubleAggregator.combine(state, vValue); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + assert value.getPositionCount() == 1; + Block unusedDeltasUncast = page.getBlock(channels.get(1)); + if (unusedDeltasUncast.areAllValuesNull()) { + return; + } + 
DoubleVector unusedDeltas = ((DoubleBlock) unusedDeltasUncast).asVector(); + assert unusedDeltas.getPositionCount() == 1; + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert seen.getPositionCount() == 1; + LossySumDoubleAggregator.combineIntermediate(state, value.getDouble(0), unusedDeltas.getDouble(0), seen.getBoolean(0)); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + if (state.seen() == false) { + blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1); + return; + } + blocks[offset] = LossySumDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..9d7981a3869a6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
package org.elasticsearch.compute.aggregation;

import java.lang.Integer;
import java.lang.Override;
import java.lang.String;
import java.util.List;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * {@link AggregatorFunctionSupplier} implementation for {@link LossySumDoubleAggregator}.
 * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead.
 */
public final class LossySumDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
    public LossySumDoubleAggregatorFunctionSupplier() {
    }

    @Override
    public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() {
        return LossySumDoubleAggregatorFunction.intermediateStateDesc();
    }

    @Override
    public List<IntermediateStateDesc> groupingIntermediateStateDesc() {
        return LossySumDoubleGroupingAggregatorFunction.intermediateStateDesc();
    }

    @Override
    public LossySumDoubleAggregatorFunction aggregator(DriverContext driverContext,
            List<Integer> channels) {
        return LossySumDoubleAggregatorFunction.create(driverContext, channels);
    }

    @Override
    public LossySumDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext,
            List<Integer> channels) {
        return LossySumDoubleGroupingAggregatorFunction.create(channels, driverContext);
    }

    @Override
    public String describe() {
        return "lossy_sum of doubles";
    }
}
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link LossySumDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. 
+ */ +public final class LossySumDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("value", ElementType.DOUBLE), + new IntermediateStateDesc("unusedDeltas", ElementType.DOUBLE), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final LossySumDoubleAggregator.GroupingSumState state; + + private final List channels; + + private final DriverContext driverContext; + + public LossySumDoubleGroupingAggregatorFunction(List channels, + LossySumDoubleAggregator.GroupingSumState state, DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static LossySumDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new LossySumDoubleGroupingAggregatorFunction(channels, LossySumDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock); + } + + @Override + public void close() { + } + }; + } + return 
new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock vBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + LossySumDoubleAggregator.combine(state, groupId, vValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector vVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + LossySumDoubleAggregator.combine(state, groupId, vValue); + } + } + } + + @Override + public 
void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block unusedDeltasUncast = page.getBlock(channels.get(1)); + if (unusedDeltasUncast.areAllValuesNull()) { + return; + } + DoubleVector unusedDeltas = ((DoubleBlock) unusedDeltasUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == unusedDeltas.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + LossySumDoubleAggregator.combineIntermediate(state, groupId, value.getDouble(valuesPosition), unusedDeltas.getDouble(valuesPosition), seen.getBoolean(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock vBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = 
groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + LossySumDoubleAggregator.combine(state, groupId, vValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector vVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + LossySumDoubleAggregator.combine(state, groupId, vValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block unusedDeltasUncast = page.getBlock(channels.get(1)); + if (unusedDeltasUncast.areAllValuesNull()) { + return; + } + DoubleVector unusedDeltas = ((DoubleBlock) unusedDeltasUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == unusedDeltas.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = 
groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + LossySumDoubleAggregator.combineIntermediate(state, groupId, value.getDouble(valuesPosition), unusedDeltas.getDouble(valuesPosition), seen.getBoolean(valuesPosition)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock vBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + LossySumDoubleAggregator.combine(state, groupId, vValue); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector vVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + double vValue = vVector.getDouble(valuesPosition); + LossySumDoubleAggregator.combine(state, groupId, vValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block valueUncast = page.getBlock(channels.get(0)); + if (valueUncast.areAllValuesNull()) { + return; + } + DoubleVector value = ((DoubleBlock) valueUncast).asVector(); + Block unusedDeltasUncast = page.getBlock(channels.get(1)); + if (unusedDeltasUncast.areAllValuesNull()) { + return; + } + DoubleVector unusedDeltas = 
((DoubleBlock) unusedDeltasUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(2)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert value.getPositionCount() == unusedDeltas.getPositionCount() && value.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + LossySumDoubleAggregator.combineIntermediate(state, groupId, value.getDouble(valuesPosition), unusedDeltas.getDouble(valuesPosition), seen.getBoolean(valuesPosition)); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, DoubleBlock vBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = LossySumDoubleAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/LossySumDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/LossySumDoubleAggregator.java new file mode 100644 index 0000000000000..7a041169d4c0f --- 
/dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/LossySumDoubleAggregator.java @@ -0,0 +1,177 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +@Aggregator( + { + @IntermediateState(name = "value", type = "DOUBLE"), + // Unlike the compensated sum, the lossy sum does not use deltas. This unused deltas block is padded for alignment + // with the compensated sum, allow the summation options are exposed to users without worrying about BWC issues. 
+ @IntermediateState(name = "unusedDeltas", type = "DOUBLE"), + @IntermediateState(name = "seen", type = "BOOLEAN") } +) +@GroupingAggregator +class LossySumDoubleAggregator { + + public static SumState initSingle() { + return new SumState(); + } + + public static void combine(SumState current, double v) { + current.value += v; + } + + public static void combine(SumState current, double value, double unusedDelta) { + assert unusedDelta == 0.0 : "Lossy sum should not have delta " + unusedDelta; + current.value += value; + } + + public static void combineIntermediate(SumState state, double inValue, double unusedDelta, boolean seen) { + // Unlike the compensated sum, the lossy sum does not use deltas. This unused deltas block is padded for alignment + // with the compensated sum, allow the summation options are exposed to users without worrying about BWC issues. + assert unusedDelta == 0.0 : "Lossy sum should not have delta " + unusedDelta; + if (seen) { + combine(state, inValue); + state.seen(true); + } + } + + public static void evaluateIntermediate(SumState state, DriverContext driverContext, Block[] blocks, int offset) { + assert blocks.length >= offset + 3; + BlockFactory blockFactory = driverContext.blockFactory(); + blocks[offset + 0] = blockFactory.newConstantDoubleBlockWith(state.value, 1); + blocks[offset + 1] = blockFactory.newConstantDoubleBlockWith(0.0, 1); + blocks[offset + 2] = blockFactory.newConstantBooleanBlockWith(state.seen(), 1); + } + + public static Block evaluateFinal(SumState state, DriverContext driverContext) { + return driverContext.blockFactory().newConstantDoubleBlockWith(state.value, 1); + } + + public static GroupingSumState initGrouping(BigArrays bigArrays) { + return new GroupingSumState(bigArrays); + } + + public static void combine(GroupingSumState current, int groupId, double v) { + current.add(v, groupId); + } + + public static void combineIntermediate(GroupingSumState current, int groupId, double value, double zeroDelta, boolean 
seen) { + assert zeroDelta == 0.0 : zeroDelta; + if (seen) { + current.add(value, groupId); + } + } + + public static void evaluateIntermediate( + GroupingSumState state, + Block[] blocks, + int offset, + IntVector selected, + DriverContext driverContext + ) { + assert blocks.length >= offset + 3; + try ( + var valuesBuilder = driverContext.blockFactory().newDoubleVectorFixedBuilder(selected.getPositionCount()); + var seenBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < state.values.size()) { + valuesBuilder.appendDouble(state.values.get(group)); + } else { + valuesBuilder.appendDouble(0); + } + seenBuilder.appendBoolean(state.hasValue(group)); + } + blocks[offset + 0] = valuesBuilder.build().asBlock(); + blocks[offset + 1] = driverContext.blockFactory().newConstantDoubleBlockWith(0.0, selected.getPositionCount()); + blocks[offset + 2] = seenBuilder.build().asBlock(); + } + } + + public static Block evaluateFinal(GroupingSumState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + try (DoubleBlock.Builder builder = ctx.blockFactory().newDoubleBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int si = selected.getInt(i); + if (state.hasValue(si) && si < state.values.size()) { + builder.appendDouble(state.values.get(si)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + static final class SumState implements AggregatorState { + private boolean seen; + double value; + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + LossySumDoubleAggregator.evaluateIntermediate(this, driverContext, blocks, offset); + } + + @Override + public void close() {} + + public boolean seen() { + return seen; + } + + public void seen(boolean seen) { + this.seen = seen; + } + } + + 
static final class GroupingSumState extends AbstractArrayState implements GroupingAggregatorState { + private DoubleArray values; + + GroupingSumState(BigArrays bigArrays) { + super(bigArrays); + boolean success = false; + try { + this.values = bigArrays.newDoubleArray(1); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + void add(double valueToAdd, int groupId) { + values = bigArrays.grow(values, groupId + 1); + values.increment(groupId, valueToAdd); + trackGroupId(groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + LossySumDoubleAggregator.evaluateIntermediate(this, blocks, offset, selected, driverContext); + } + + @Override + public void close() { + Releasables.close(values, super::close); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunctionTests.java new file mode 100644 index 0000000000000..3bc9e623dcef4 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/LossySumDoubleAggregatorFunctionTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.TestDriverFactory; +import org.elasticsearch.compute.test.TestResultPageSinkOperator; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.DoubleStream; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class LossySumDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceDoubleBlockSourceOperator(blockFactory, LongStream.range(0, size).mapToDouble(l -> ESTestCase.randomDouble())); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction() { + return new LossySumDoubleAggregatorFunctionSupplier(); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "lossy_sum of doubles"; + } + + @Override + protected void assertSimpleOutput(List input, Block result) { + double sum = input.stream().flatMapToDouble(p -> allDoubles(p.getBlock(0))).sum(); + assertThat(((DoubleBlock) result).getDouble(0), closeTo(sum, .0001)); + } + + public void testOverflowSucceeds() { + DriverContext driverContext = driverContext(); + List results = new ArrayList<>(); + try ( + Driver d = TestDriverFactory.create( + driverContext, + new SequenceDoubleBlockSourceOperator(driverContext.blockFactory(), DoubleStream.of(Double.MAX_VALUE - 1, 2)), + 
List.of(simple().get(driverContext)), + new TestResultPageSinkOperator(results::add) + ) + ) { + runDriver(d); + } + assertThat(results.get(0).getBlock(0).getDouble(0), equalTo(Double.MAX_VALUE + 1)); + assertDriverContext(driverContext); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/LossySumDoubleGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/LossySumDoubleGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..8a3a5b09bdc80 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/LossySumDoubleGroupingAggregatorFunctionTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongDoubleTupleBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.closeTo; + +public class LossySumDoubleGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int end) { + return new LongDoubleTupleBlockSourceOperator( + blockFactory, + LongStream.range(0, end).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomDouble())) + ); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction() { + return new 
LossySumDoubleAggregatorFunctionSupplier(); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "lossy_sum of doubles"; + } + + @Override + protected void assertSimpleGroup(List input, Block result, int position, Long group) { + double sum = input.stream().flatMapToDouble(p -> allDoubles(p, group)).sum(); + // Won't precisely match in distributed case but will be close + assertThat(((DoubleBlock) result).getDouble(position), closeTo(sum, 0.01)); + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java index f7833b917b746..6621c29921ad4 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.compute.lucene.TimeSeriesSourceOperator; import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.compute.operator.OperatorStatus; import org.elasticsearch.compute.operator.TimeSeriesAggregationOperator; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -28,8 +29,10 @@ import static org.elasticsearch.index.mapper.DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.not; public class TimeSeriesIT extends AbstractEsqlIntegTestCase { @@ -479,7 +482,7 @@ public void testFieldDoesNotExist() { } } - public void testProfile() { + public void testRateProfile() { EsqlQueryRequest request = new 
EsqlQueryRequest(); request.profile(true); request.query("TS hosts | STATS sum(rate(request_count)) BY cluster, bucket(@timestamp, 1minute) | SORT cluster"); @@ -508,4 +511,24 @@ public void testProfile() { assertThat(totalTimeSeries, equalTo(dataProfiles.size() / 3)); } } + + public void testAvgOrSumOverTimeProfile() { + EsqlQueryRequest request = new EsqlQueryRequest(); + request.profile(true); + String tsFunction = randomFrom("sum_over_time", "avg_over_time"); + request.query("TS hosts | STATS AVG(" + tsFunction + "(cpu)) BY cluster, bucket(@timestamp, 1minute) | SORT cluster"); + try (var resp = run(request)) { + EsqlQueryResponse.Profile profile = resp.profile(); + List dataProfiles = profile.drivers().stream().filter(d -> d.description().equals("data")).toList(); + assertThat(dataProfiles, not(empty())); + for (DriverProfile p : dataProfiles) { + List aggregatorOperators = p.operators() + .stream() + .filter(s -> s.status() instanceof TimeSeriesAggregationOperator.Status) + .toList(); + assertThat(aggregatorOperators, hasSize(1)); + assertThat(aggregatorOperators.getFirst().operator(), containsString("LossySumDoubleGroupingAggregatorFunction")); + } + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 62554570ba3c0..60ae8a4ba57c7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -16,7 +16,9 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import 
org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; @@ -68,6 +70,39 @@ protected AggregateFunction(StreamInput in) throws IOException { ); } + /** + * Read a generic AggregateFunction from the stream input. This is used for BWC when the subclass requires a generic instance; + * then convert the parameters to the specific ones. + */ + protected static AggregateFunction readGenericAggregateFunction(StreamInput in) throws IOException { + return new AggregateFunction(in) { + @Override + public AggregateFunction withFilter(Expression filter) { + throw new UnsupportedOperationException(); + } + + @Override + public DataType dataType() { + throw new UnsupportedOperationException(); + } + + @Override + public Expression replaceChildren(List newChildren) { + throw new UnsupportedOperationException(); + } + + @Override + protected NodeInfo info() { + throw new UnsupportedOperationException(); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException(); + } + }; + } + @Override public final void writeTo(StreamOutput out) throws IOException { source().writeTo(out); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java index 931321bab4b1b..f9f5bcf23ce7d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java @@ -25,13 +25,13 @@ import java.io.IOException; import java.util.List; -import static java.util.Collections.emptyList; import static 
org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE; public class Avg extends AggregateFunction implements SurrogateExpression { - public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new); + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::readFrom); + private final Expression summationMode; @FunctionInfo( returnType = "double", @@ -55,11 +55,16 @@ public Avg( description = "Expression that outputs values to average." ) Expression field ) { - this(source, field, Literal.TRUE); + this(source, field, Literal.TRUE, SummationMode.COMPENSATED_LITERAL); } - public Avg(Source source, Expression field, Expression filter) { - super(source, field, filter, emptyList()); + public Avg(Source source, Expression field, Expression filter, Expression summationMode) { + super(source, field, filter, List.of(summationMode)); + this.summationMode = summationMode; + } + + public Expression summationMode() { + return summationMode; } @Override @@ -73,8 +78,12 @@ protected Expression.TypeResolution resolveType() { ); } - private Avg(StreamInput in) throws IOException { - super(in); + private static Avg readFrom(StreamInput in) throws IOException { + // For BWC and to ensure parameters always include the summation mode, first read a generic AggregateFunction, then convert to AVG. + var fn = readGenericAggregateFunction(in); + var parameters = fn.parameters(); + var summationMode = parameters.isEmpty() ? 
SummationMode.COMPENSATED_LITERAL : parameters.getFirst(); + return new Avg(fn.source(), fn.field(), fn.filter(), summationMode); } @Override @@ -89,17 +98,17 @@ public DataType dataType() { @Override protected NodeInfo info() { - return NodeInfo.create(this, Avg::new, field(), filter()); + return NodeInfo.create(this, Avg::new, field(), filter(), summationMode); } @Override public Avg replaceChildren(List newChildren) { - return new Avg(source(), newChildren.get(0), newChildren.get(1)); + return new Avg(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); } @Override public Avg withFilter(Expression filter) { - return new Avg(source(), field(), filter); + return new Avg(source(), field(), filter, summationMode); } @Override @@ -110,8 +119,8 @@ public Expression surrogate() { return new MvAvg(s, field); } if (field.dataType() == AGGREGATE_METRIC_DOUBLE) { - return new Div(s, new Sum(s, field, filter()).surrogate(), new Count(s, field, filter()).surrogate()); + return new Div(s, new Sum(s, field, filter(), summationMode).surrogate(), new Count(s, field, filter()).surrogate()); } - return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType()); + return new Div(s, new Sum(s, field, filter(), summationMode), new Count(s, field, filter()), dataType()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java index 5e8c3a9bcf104..0f8c1dddafcb5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgOverTime.java @@ -104,6 +104,6 @@ public Expression surrogate() { @Override public AggregateFunction perTimeSeriesAggregation() { - return new Avg(source(), field(), filter()); + return new 
Avg(source(), field(), filter(), SummationMode.LOSSY_LITERAL); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index d4a940bbf6648..2877682f3749c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.LossySumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; @@ -32,7 +33,6 @@ import java.io.IOException; import java.util.List; -import static java.util.Collections.emptyList; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE; @@ -44,7 +44,9 @@ * Sum all values of a field in matching documents. 
*/ public class Sum extends NumericAggregate implements SurrogateExpression { - public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new); + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::readFrom); + + private final Expression summationMode; @FunctionInfo( returnType = { "long", "double" }, @@ -61,15 +63,20 @@ public class Sum extends NumericAggregate implements SurrogateExpression { ) } ) public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_double", "double", "integer", "long" }) Expression field) { - this(source, field, Literal.TRUE); + this(source, field, Literal.TRUE, SummationMode.COMPENSATED_LITERAL); } - public Sum(Source source, Expression field, Expression filter) { - super(source, field, filter, emptyList()); + public Sum(Source source, Expression field, Expression filter, Expression summationMode) { + super(source, field, filter, List.of(summationMode)); + this.summationMode = summationMode; } - private Sum(StreamInput in) throws IOException { - super(in); + private static Sum readFrom(StreamInput in) throws IOException { + // For BWC and to ensure parameters always include the summation mode, first read a generic AggregateFunction, then convert to SUM. + var fn = readGenericAggregateFunction(in); + var parameters = fn.parameters(); + var summationMode = parameters.isEmpty() ? 
SummationMode.COMPENSATED_LITERAL : parameters.getFirst(); + return new Sum(fn.source(), fn.field(), fn.filter(), summationMode); } @Override @@ -78,18 +85,18 @@ public String getWriteableName() { } @Override - protected NodeInfo info() { - return NodeInfo.create(this, Sum::new, field(), filter()); + protected NodeInfo info() { + return NodeInfo.create(this, Sum::new, field(), filter(), summationMode); } @Override public Sum replaceChildren(List newChildren) { - return new Sum(source(), newChildren.get(0), newChildren.get(1)); + return new Sum(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); } @Override public Sum withFilter(Expression filter) { - return new Sum(source(), field(), filter); + return new Sum(source(), field(), filter, summationMode); } @Override @@ -110,7 +117,15 @@ protected AggregatorFunctionSupplier intSupplier() { @Override protected AggregatorFunctionSupplier doubleSupplier() { - return new SumDoubleAggregatorFunctionSupplier(); + final SummationMode mode = SummationMode.fromLiteral(summationMode); + return switch (mode) { + case COMPENSATED -> new SumDoubleAggregatorFunctionSupplier(); + case LOSSY -> new LossySumDoubleAggregatorFunctionSupplier(); + }; + } + + public Expression summationMode() { + return summationMode; } @Override @@ -139,7 +154,12 @@ public Expression surrogate() { var s = source(); var field = field(); if (field.dataType() == AGGREGATE_METRIC_DOUBLE) { - return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM)); + return new Sum( + s, + FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM), + filter(), + summationMode + ); } // SUM(const) is equivalent to MV_SUM(const)*COUNT(*). 
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumOverTime.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumOverTime.java index 3e918e046633c..d0f47b37041c1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumOverTime.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumOverTime.java @@ -85,6 +85,6 @@ public DataType dataType() { @Override public Sum perTimeSeriesAggregation() { - return new Sum(source(), field(), filter()); + return new Sum(source(), field(), filter(), SummationMode.LOSSY_LITERAL); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SummationMode.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SummationMode.java new file mode 100644 index 0000000000000..abd7b2cace4fa --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SummationMode.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * Specifies the summation algorithm to use for aggregating floating point values. + */ +public enum SummationMode { + /** + * The default mode in regular aggregations. 
+ * Uses Kahan summation for improved floating point precision. + */ + COMPENSATED("compensated"), + + /** + * The default mode in time-series aggregations. + * Uses simple summation, allowing loss of precision for performance. + */ + LOSSY("lossy") + + ; + + public static final Literal COMPENSATED_LITERAL = COMPENSATED.asLiteral(); + public static final Literal LOSSY_LITERAL = LOSSY.asLiteral(); + + private final String mode; + + SummationMode(String mode) { + this.mode = mode; + } + + private Literal asLiteral() { + return Literal.keyword(Source.EMPTY, mode); + } + + public static SummationMode fromLiteral(Expression literal) { + return fromString(BytesRefs.toString(literal.fold(FoldContext.small()))); + } + + private static SummationMode fromString(String mode) { + return switch (mode) { + case "compensated" -> COMPENSATED; + case "lossy" -> LOSSY; + default -> throw new IllegalArgumentException("unknown summation mode [" + mode + "]"); + }; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index 0c9e20ab16230..cc3528622203f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -45,6 +45,7 @@ import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern; import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode; import org.elasticsearch.xpack.esql.plan.IndexPattern; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.ChangePoint; @@ -743,7 +744,11 @@ public PlanFactory visitFuseCommand(EsqlBaseParser.FuseCommandContext ctx) { 
Attribute idAttr = new UnresolvedAttribute(source, IdFieldMapper.NAME); Attribute indexAttr = new UnresolvedAttribute(source, MetadataAttribute.INDEX); List aggregates = List.of( - new Alias(source, MetadataAttribute.SCORE, new Sum(source, scoreAttr, new Literal(source, true, DataType.BOOLEAN))) + new Alias( + source, + MetadataAttribute.SCORE, + new Sum(source, scoreAttr, new Literal(source, true, DataType.BOOLEAN), SummationMode.COMPENSATED_LITERAL) + ) ); List groupings = List.of(idAttr, indexAttr); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java index 704a57be15310..5d6fa21e20813 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java @@ -7,18 +7,86 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; public class AvgSerializationTests extends AbstractExpressionSerializationTests { @Override protected Avg 
createTestInstance() { - return new Avg(randomSource(), randomChild()); + return new Avg(randomSource(), randomChild(), randomChild(), randomChild()); } @Override protected Avg mutateInstance(Avg instance) throws IOException { - return new Avg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + Expression field = instance.field(); + Expression filter = instance.filter(); + Expression summationMode = instance.summationMode(); + switch (randomIntBetween(0, 2)) { + case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + case 1 -> filter = randomValueOtherThan(filter, AbstractExpressionSerializationTests::randomChild); + case 2 -> summationMode = randomValueOtherThan(summationMode, AbstractExpressionSerializationTests::randomChild); + default -> throw new AssertionError("unexpected value"); + } + return new Avg(instance.source(), field, filter, summationMode); + } + + public static class OldAvg extends AggregateFunction { + public OldAvg(Source source, Expression field, Expression filter) { + super(source, field, filter, List.of()); + } + + @Override + public AggregateFunction withFilter(Expression filter) { + return new OldAvg(source(), filter, filter); + } + + @Override + public DataType dataType() { + return field().dataType(); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new OldAvg(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, OldAvg::new, field(), filter()); + } + + @Override + public String getWriteableName() { + return Avg.ENTRY.name; + } + } + + public void testSerializeOldAvg() throws IOException { + var oldAvg = new OldAvg(randomSource(), randomChild(), randomChild()); + try (BytesStreamOutput out = new BytesStreamOutput()) { + PlanStreamOutput planOut = new PlanStreamOutput(out, configuration()); + planOut.writeNamedWriteable(oldAvg); 
+ try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), getNamedWriteableRegistry())) { + PlanStreamInput planIn = new PlanStreamInput(in, getNamedWriteableRegistry(), configuration()); + Avg serialized = (Avg) planIn.readNamedWriteable(categoryClass()); + assertThat(serialized.source(), equalTo(oldAvg.source())); + assertThat(serialized.field(), equalTo(oldAvg.field())); + assertThat(serialized.filter(), equalTo(oldAvg.filter())); + assertThat(serialized.summationMode(), equalTo(SummationMode.COMPENSATED_LITERAL)); + } + } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java index 8126b4a30bdb0..f15d8bb96cf24 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java @@ -7,18 +7,85 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; public class SumSerializationTests extends 
AbstractExpressionSerializationTests { @Override protected Sum createTestInstance() { - return new Sum(randomSource(), randomChild()); + return new Sum(randomSource(), randomChild(), randomChild(), randomChild()); } @Override protected Sum mutateInstance(Sum instance) throws IOException { - return new Sum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + Expression field = instance.field(); + Expression filter = instance.filter(); + Expression summationMode = instance.summationMode(); + switch (randomIntBetween(0, 2)) { + case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + case 1 -> filter = randomValueOtherThan(filter, AbstractExpressionSerializationTests::randomChild); + case 2 -> summationMode = randomValueOtherThan(summationMode, AbstractExpressionSerializationTests::randomChild); + default -> throw new AssertionError("unexpected value"); + } + return new Sum(instance.source(), field, filter, summationMode); + } + + public static class OldSum extends AggregateFunction { + public OldSum(Source source, Expression field, Expression filter) { + super(source, field, filter, List.of()); + } + + @Override + public AggregateFunction withFilter(Expression filter) { + return new OldSum(source(), filter, filter); + } + + @Override + public DataType dataType() { + return field().dataType(); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new OldSum(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, OldSum::new, field(), filter()); + } + + @Override + public String getWriteableName() { + return Sum.ENTRY.name; + } + } + + public void testSerializeOldSum() throws IOException { + var oldSum = new OldSum(randomSource(), randomChild(), randomChild()); + try (BytesStreamOutput out = new BytesStreamOutput()) { + PlanStreamOutput planOut = new 
PlanStreamOutput(out, configuration()); + planOut.writeNamedWriteable(oldSum); + try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), getNamedWriteableRegistry())) { + PlanStreamInput planIn = new PlanStreamInput(in, getNamedWriteableRegistry(), configuration()); + Sum serialized = (Sum) planIn.readNamedWriteable(categoryClass()); + assertThat(serialized.source(), equalTo(oldSum.source())); + assertThat(serialized.field(), equalTo(oldSum.field())); + assertThat(serialized.summationMode(), equalTo(SummationMode.COMPENSATED_LITERAL)); + } + } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index dd48cfac29ea0..373c8d86e6be7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -54,6 +54,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode; import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; @@ -7608,6 +7609,7 @@ public void testTranslateAvgOverTime() { assertThat(Expressions.attribute(finalAgg.groupings().get(0)).id(), equalTo(aggsByTsid.aggregates().get(2).id())); Sum sumTs = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Sum.class); + assertThat(sumTs.summationMode(), equalTo(SummationMode.LOSSY_LITERAL)); assertThat(Expressions.attribute(sumTs.field()).name(), 
equalTo("network.bytes_in")); Count countTs = as(Alias.unwrap(aggsByTsid.aggregates().get(1)), Count.class); assertThat(Expressions.attribute(countTs.field()).name(), equalTo("network.bytes_in"));