TimeSeriesAggregationOperator.java
@@ -7,6 +7,7 @@

package org.elasticsearch.compute.operator;

import org.elasticsearch.common.Rounding;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
@@ -23,6 +24,7 @@
public class TimeSeriesAggregationOperator extends HashAggregationOperator {

public record Factory(
Rounding.Prepared timeBucket,
List<BlockHash.GroupSpec> groups,
AggregatorMode aggregatorMode,
List<GroupingAggregator.Factory> aggregators,
@@ -32,6 +34,7 @@ public record Factory(
public Operator get(DriverContext driverContext) {
// TODO: use TimeSeriesBlockHash when possible
return new TimeSeriesAggregationOperator(
timeBucket,
aggregators,
() -> BlockHash.build(
groups,
@@ -53,11 +56,15 @@ public String describe() {
}
}

private final Rounding.Prepared timeBucket;

public TimeSeriesAggregationOperator(
Rounding.Prepared timeBucket,
List<GroupingAggregator.Factory> aggregators,
Supplier<BlockHash> blockHash,
DriverContext driverContext
) {
super(aggregators, blockHash, driverContext);
this.timeBucket = timeBucket;
}
}
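
For context, `Rounding.Prepared` (the type now threaded through this operator) maps raw epoch-millis timestamps to bucket start times, and preparing the rounding once avoids re-resolving calendar and time-zone details per row. A minimal sketch of that behavior, assuming the standard `org.elasticsearch.common.Rounding` API; the 5-minute interval and the timestamp are illustrative only:

```java
import org.elasticsearch.common.Rounding;
import org.elasticsearch.core.TimeValue;

class RoundingSketch {
    static void demo() {
        // Build a fixed 5-minute rounding, then prepare it for arbitrary inputs.
        Rounding.Prepared prepared = Rounding.builder(TimeValue.timeValueMinutes(5))
            .build()
            .prepareForUnknown();

        long ts = 1_700_000_123_456L;                        // an epoch-millis timestamp
        long bucketStart = prepared.round(ts);               // floor ts to its bucket start
        long next = prepared.nextRoundingValue(bucketStart); // start of the following bucket
    }
}
```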
TimeSeriesAggregationOperatorFactories.java
@@ -7,6 +7,7 @@

package org.elasticsearch.compute.operator;

import org.elasticsearch.common.Rounding;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
@@ -46,6 +47,7 @@ public record SupplierWithChannels(AggregatorFunctionSupplier supplier, List<Integer> channels)
public record Initial(
int tsHashChannel,
int timeBucketChannel,
Rounding.Prepared timeBucket,
List<BlockHash.GroupSpec> groupings,
List<SupplierWithChannels> rates,
List<SupplierWithChannels> nonRates,
@@ -62,6 +64,7 @@ public Operator get(DriverContext driverContext) {
}
aggregators.addAll(valuesAggregatorForGroupings(groupings, timeBucketChannel));
return new TimeSeriesAggregationOperator(
timeBucket,
aggregators,
() -> new TimeSeriesBlockHash(tsHashChannel, timeBucketChannel, driverContext.blockFactory()),
driverContext
@@ -77,6 +80,7 @@ public String describe() {
public record Intermediate(
int tsHashChannel,
int timeBucketChannel,
Rounding.Prepared timeBucket,
List<BlockHash.GroupSpec> groupings,
List<SupplierWithChannels> rates,
List<SupplierWithChannels> nonRates,
@@ -97,6 +101,7 @@ public Operator get(DriverContext driverContext) {
new BlockHash.GroupSpec(timeBucketChannel, ElementType.LONG)
);
return new TimeSeriesAggregationOperator(
timeBucket,
aggregators,
() -> BlockHash.build(hashGroups, driverContext.blockFactory(), maxPageSize, true),
driverContext
[test file; name not captured]
@@ -270,6 +270,7 @@ public void close() {
Operator initialAgg = new TimeSeriesAggregationOperatorFactories.Initial(
1,
3,
rounding,
IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, ElementType.BYTES_REF)).toList(),
List.of(new SupplierWithChannels(new RateLongAggregatorFunctionSupplier(unitInMillis), List.of(4, 2))),
List.of(),
@@ -280,6 +281,7 @@
Operator intermediateAgg = new TimeSeriesAggregationOperatorFactories.Intermediate(
0,
1,
rounding,
IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, ElementType.BYTES_REF)).toList(),
List.of(new SupplierWithChannels(new RateLongAggregatorFunctionSupplier(unitInMillis), List.of(2, 3, 4))),
List.of(),
Bucket.java
@@ -262,16 +262,7 @@ public boolean foldable() {
@Override
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) {
Rounding.Prepared preparedRounding;
if (buckets.dataType().isWholeNumber()) {
int b = ((Number) buckets.fold(toEvaluator.foldCtx())).intValue();
long f = foldToLong(toEvaluator.foldCtx(), from);
long t = foldToLong(toEvaluator.foldCtx(), to);
preparedRounding = new DateRoundingPicker(b, f, t).pickRounding().prepareForUnknown();
} else {
assert DataType.isTemporalAmount(buckets.dataType()) : "Unexpected span data type [" + buckets.dataType() + "]";
preparedRounding = DateTrunc.createRounding(buckets.fold(toEvaluator.foldCtx()), DEFAULT_TZ);
}
Rounding.Prepared preparedRounding = getDateRounding(toEvaluator.foldCtx());
return DateTrunc.evaluator(field.dataType(), source(), toEvaluator.apply(field), preparedRounding);
}
if (field.dataType().isNumeric()) {
@@ -295,6 +286,30 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
throw EsqlIllegalArgumentException.illegalDataType(field.dataType());
}

/**
* Returns the date rounding from this bucket function if the target field is a date type; otherwise, returns null.
*/
public Rounding.Prepared getDateRoundingOrNull(FoldContext foldCtx) {
if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) {
return getDateRounding(foldCtx);
} else {
return null;
}
}

private Rounding.Prepared getDateRounding(FoldContext foldContext) {
assert field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS : "expected date type; got " + field;
if (buckets.dataType().isWholeNumber()) {
int b = ((Number) buckets.fold(foldContext)).intValue();
long f = foldToLong(foldContext, from);
long t = foldToLong(foldContext, to);
return new DateRoundingPicker(b, f, t).pickRounding().prepareForUnknown();
} else {
assert DataType.isTemporalAmount(buckets.dataType()) : "Unexpected span data type [" + buckets.dataType() + "]";
return DateTrunc.createRounding(buckets.fold(foldContext), DEFAULT_TZ);
}
}

private record DateRoundingPicker(int buckets, long from, long to) {
Rounding pickRounding() {
Rounding prev = LARGEST_HUMAN_DATE_ROUNDING;
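
The `pickRounding()` body is truncated by this hunk; conceptually, the picker starts from the largest human-friendly calendar rounding and settles on one whose bucket count over [from, to] fits the requested `buckets`. A simplified sketch of that idea; the candidate list, its coarse-to-fine ordering, and the "at most `buckets`" budget rule are assumptions here, not the verbatim implementation:

```java
import java.util.List;

import org.elasticsearch.common.Rounding;

// Illustrative picker: candidates are ordered coarse to fine, so bucket counts
// only grow along the list; keep the last candidate that fits the budget.
record PickerSketch(int buckets, long fromMillis, long toMillis) {

    Rounding pick(List<Rounding> candidatesCoarseToFine) {
        Rounding best = candidatesCoarseToFine.get(0); // assumes a non-empty list
        for (Rounding candidate : candidatesCoarseToFine) {
            if (bucketCount(candidate) <= buckets) {
                best = candidate; // fits the budget; a finer rounding may fit too
            } else {
                break;            // too fine, keep the previous pick
            }
        }
        return best;
    }

    private long bucketCount(Rounding rounding) {
        Rounding.Prepared prepared = rounding.prepareForUnknown();
        long count = 0;
        for (long b = prepared.round(fromMillis); b <= toMillis; b = prepared.nextRoundingValue(b)) {
            count++;
        }
        return count;
    }
}
```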
TranslateTimeSeriesAggregate.java
@@ -123,7 +123,7 @@ public TranslateTimeSeriesAggregate() {

@Override
protected LogicalPlan rule(Aggregate aggregate) {
if (aggregate instanceof TimeSeriesAggregate ts) {
if (aggregate instanceof TimeSeriesAggregate ts && ts.timeBucket() == null) {
return translate(ts);
} else {
return aggregate;
@@ -226,7 +226,8 @@ LogicalPlan translate(TimeSeriesAggregate aggregate) {
newChild.source(),
newChild,
firstPassGroupings,
mergeExpressions(firstPassAggs, firstPassGroupings)
mergeExpressions(firstPassAggs, firstPassGroupings),
(Bucket) Alias.unwrap(timeBucket)
);
return new Aggregate(firstPhase.source(), firstPhase, secondPassGroupings, mergeExpressions(secondPassAggs, secondPassGroupings));
}
[parser file; name not captured]
@@ -315,7 +315,7 @@ public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) {
final Stats stats = stats(source(ctx), ctx.grouping, ctx.stats);
return input -> {
if (input.anyMatch(p -> p instanceof UnresolvedRelation ur && ur.indexMode() == IndexMode.TIME_SERIES)) {
return new TimeSeriesAggregate(source(ctx), input, stats.groupings, stats.aggregates);
return new TimeSeriesAggregate(source(ctx), input, stats.groupings, stats.aggregates, null);
} else {
return new Aggregate(source(ctx), input, stats.groupings, stats.aggregates);
}
TimeSeriesAggregate.java
@@ -9,10 +9,12 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;

import java.io.IOException;
import java.util.List;
@@ -28,17 +30,28 @@ public class TimeSeriesAggregate extends Aggregate {
TimeSeriesAggregate::new
);

public TimeSeriesAggregate(Source source, LogicalPlan child, List<Expression> groupings, List<? extends NamedExpression> aggregates) {
private final Bucket timeBucket;

public TimeSeriesAggregate(
Source source,
LogicalPlan child,
List<Expression> groupings,
List<? extends NamedExpression> aggregates,
Bucket timeBucket
) {
super(source, child, groupings, aggregates);
this.timeBucket = timeBucket;
}

public TimeSeriesAggregate(StreamInput in) throws IOException {
super(in);
this.timeBucket = in.readOptionalWriteable(inp -> (Bucket) Bucket.ENTRY.reader.read(inp));
[inline review] Contributor: Does this need to be versioned?

[inline review] Member (author): I disabled time_series tests in mixed clusters to allow us to move quickly here.

}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalWriteable(timeBucket);
}

@Override
Expand All @@ -48,16 +61,21 @@ public String getWriteableName() {

@Override
protected NodeInfo<Aggregate> info() {
return NodeInfo.create(this, TimeSeriesAggregate::new, child(), groupings, aggregates);
return NodeInfo.create(this, TimeSeriesAggregate::new, child(), groupings, aggregates, timeBucket);
}

@Override
public TimeSeriesAggregate replaceChild(LogicalPlan newChild) {
return new TimeSeriesAggregate(source(), newChild, groupings, aggregates);
return new TimeSeriesAggregate(source(), newChild, groupings, aggregates, timeBucket);
}

@Override
public TimeSeriesAggregate with(LogicalPlan child, List<Expression> newGroupings, List<? extends NamedExpression> newAggregates) {
return new TimeSeriesAggregate(source(), child, newGroupings, newAggregates);
return new TimeSeriesAggregate(source(), child, newGroupings, newAggregates, timeBucket);
}

@Nullable
public Bucket timeBucket() {
return timeBucket;
}
}
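
The two constructors and `writeTo` above rely on Elasticsearch's optional-writeable wire pattern: a presence boolean, then the payload only when the field is non-null, with the read side mirroring the write side exactly. That symmetry is what the versioning question is probing, since an older node would not expect the extra field at all. A minimal sketch of the pattern, with `Part` as a placeholder for the `Bucket` field:

```java
import java.io.IOException;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

final class OptionalWireSketch {
    // Placeholder Writeable; in the diff above the optional field is a Bucket.
    record Part(long value) implements Writeable {
        Part(StreamInput in) throws IOException {
            this(in.readLong());
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeLong(value);
        }
    }

    // Write side: presence flag, then the payload only when non-null.
    static void serialize(StreamOutput out, Part partOrNull) throws IOException {
        out.writeOptionalWriteable(partOrNull);
    }

    // Read side must mirror the write side, or the stream desyncs; this is why
    // new wire fields are normally gated behind transport-version checks.
    static Part deserialize(StreamInput in) throws IOException {
        return in.readOptionalWriteable(Part::new);
    }
}
```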
TimeSeriesAggregateExec.java
@@ -7,15 +7,19 @@

package org.elasticsearch.xpack.esql.plan.physical;

import org.elasticsearch.common.Rounding;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;

import java.io.IOException;
@@ -32,25 +36,31 @@ public class TimeSeriesAggregateExec extends AggregateExec {
TimeSeriesAggregateExec::new
);

private final Bucket timeBucket;

public TimeSeriesAggregateExec(
Source source,
PhysicalPlan child,
List<? extends Expression> groupings,
List<? extends NamedExpression> aggregates,
AggregatorMode mode,
List<Attribute> intermediateAttributes,
Integer estimatedRowSize
Integer estimatedRowSize,
Bucket timeBucket
) {
super(source, child, groupings, aggregates, mode, intermediateAttributes, estimatedRowSize);
this.timeBucket = timeBucket;
}

private TimeSeriesAggregateExec(StreamInput in) throws IOException {
super(in);
this.timeBucket = in.readOptionalWriteable(inp -> (Bucket) Bucket.ENTRY.reader.read(inp));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalWriteable(timeBucket);
}

@Override
@@ -68,7 +78,8 @@ protected NodeInfo<AggregateExec> info() {
aggregates(),
getMode(),
intermediateAttributes(),
estimatedRowSize()
estimatedRowSize(),
timeBucket
);
}

@@ -81,7 +92,8 @@ public TimeSeriesAggregateExec replaceChild(PhysicalPlan newChild) {
aggregates(),
getMode(),
intermediateAttributes(),
estimatedRowSize()
estimatedRowSize(),
timeBucket
);
}

@@ -93,7 +105,8 @@ public TimeSeriesAggregateExec withMode(AggregatorMode newMode) {
aggregates(),
newMode,
intermediateAttributes(),
estimatedRowSize()
estimatedRowSize(),
timeBucket
);
}

@@ -106,7 +119,23 @@ protected AggregateExec withEstimatedSize(int estimatedRowSize) {
aggregates(),
getMode(),
intermediateAttributes(),
estimatedRowSize
estimatedRowSize,
timeBucket
);
}

public Bucket timeBucket() {
return timeBucket;
}

public Rounding.Prepared timeBucketRounding(FoldContext foldContext) {
if (timeBucket == null) {
return null;
}
Rounding.Prepared rounding = timeBucket.getDateRoundingOrNull(foldContext);
if (rounding == null) {
throw new EsqlIllegalArgumentException("expected TBUCKET; got {}", timeBucket);
}
return rounding;
}
}
[physical planner file; name not captured]
@@ -174,8 +174,9 @@ else if (aggregatorMode.isOutputPartial()) {
s -> aggregatorFactories.add(s.supplier.groupingAggregatorFactory(s.mode, s.channels))
);
// time-series aggregation
if (aggregateExec instanceof TimeSeriesAggregateExec) {
if (aggregateExec instanceof TimeSeriesAggregateExec ts) {
operatorFactory = new TimeSeriesAggregationOperator.Factory(
ts.timeBucketRounding(context.foldCtx()),
groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(),
aggregatorMode,
aggregatorFactories,
[mapper file; name not captured]
@@ -131,15 +131,16 @@ static List<Attribute> intermediateAttributes(Aggregate aggregate) {
}

static AggregateExec aggExec(Aggregate aggregate, PhysicalPlan child, AggregatorMode aggMode, List<Attribute> intermediateAttributes) {
if (aggregate instanceof TimeSeriesAggregate) {
if (aggregate instanceof TimeSeriesAggregate ts) {
return new TimeSeriesAggregateExec(
aggregate.source(),
child,
aggregate.groupings(),
aggregate.aggregates(),
aggMode,
intermediateAttributes,
null
null,
ts.timeBucket()
);
} else {
return new AggregateExec(