@@ -181,9 +181,11 @@ private static Operator operator(DriverContext driverContext, String grouping, S
             );
             default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
         };
+        int pageSize = 16 * 1024;
         return new HashAggregationOperator(
             List.of(supplier(op, dataType, filter).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(groups.size()))),
-            () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
+            () -> BlockHash.build(groups, driverContext.blockFactory(), pageSize, false),
+            pageSize,
             driverContext
         );
     }
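The hunk above threads one page size through both the BlockHash and the new HashAggregationOperator constructor parameter. A minimal sketch of the resulting construction contract, under assumed names (aggregatorFactories, groups, blockFactory, and driverContext are illustrative stand-ins, not taken from this diff):

// Hedged sketch of a construction site after this change.
int pageSize = 16 * 1024;
Operator operator = new HashAggregationOperator(
    aggregatorFactories,                                          // List<GroupingAggregator.Factory>, assumed in scope
    () -> BlockHash.build(groups, blockFactory, pageSize, false), // hash built with the same page size
    pageSize,                                                     // new parameter: caps positions per output page
    driverContext
);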
@@ -22,6 +22,7 @@
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
@@ -58,12 +59,14 @@ public Operator get(DriverContext driverContext) {
                     analysisRegistry,
                     maxPageSize
                 ),
+                Integer.MAX_VALUE, // TODO: doesn't support chunking yet
                 driverContext
             );
         }
         return new HashAggregationOperator(
             aggregators,
             () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
+            maxPageSize,
             driverContext
         );
     }
@@ -78,9 +81,10 @@ public String describe() {
         }
     }

-    private boolean finished;
-    private Page output;
+    private final int maxPageSize;
+    private Emitter emitter;

+    private boolean blockHashClosed = false;
     private final BlockHash blockHash;

     private final List<GroupingAggregator> aggregators;
@@ -112,8 +116,10 @@ public String describe() {
     public HashAggregationOperator(
         List<GroupingAggregator.Factory> aggregators,
         Supplier<BlockHash> blockHash,
+        int maxPageSize,
         DriverContext driverContext
     ) {
+        this.maxPageSize = maxPageSize;
         this.aggregators = new ArrayList<>(aggregators.size());
         this.driverContext = driverContext;
         boolean success = false;
@@ -132,7 +138,7 @@ public HashAggregationOperator(

     @Override
     public boolean needsInput() {
-        return finished == false;
+        return emitter == null;
     }

     @Override
@@ -201,59 +207,98 @@ public void close() {

     @Override
     public Page getOutput() {
-        Page p = output;
-        if (p != null) {
-            rowsEmitted += p.getPositionCount();
+        if (emitter == null) {
+            return null;
         }
-        output = null;
-        return p;
+        return emitter.nextPage();
     }

-    @Override
-    public void finish() {
-        if (finished) {
-            return;
+    private class Emitter implements Releasable {
+        private final int[] aggBlockCounts;
+        private int position = -1;
+        private IntVector allSelected = null;
+        private Block[] allKeys;
+
+        Emitter(int[] aggBlockCounts) {
+            this.aggBlockCounts = aggBlockCounts;
         }
-        finished = true;
-        Block[] blocks = null;
-        IntVector selected = null;
-        boolean success = false;
-        try {
-            selected = blockHash.nonEmpty();
-            Block[] keys = blockHash.getKeys();
-            int[] aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray();
-            blocks = new Block[keys.length + Arrays.stream(aggBlockCounts).sum()];
-            System.arraycopy(keys, 0, blocks, 0, keys.length);
-            int offset = keys.length;
-            for (int i = 0; i < aggregators.size(); i++) {
-                var aggregator = aggregators.get(i);
-                aggregator.evaluate(blocks, offset, selected, driverContext);
-                offset += aggBlockCounts[i];
+
+        Page nextPage() {
+            if (position == -1) {
+                position = 0;
+                // TODO: chunk selected and keys
+                allKeys = blockHash.getKeys();
+                allSelected = blockHash.nonEmpty();
+                blockHashClosed = true;
+                blockHash.close();
             }
-            output = new Page(blocks);
-            success = true;
-        } finally {
-            // selected should always be closed
-            if (selected != null) {
-                selected.close();
+            final int endPosition = Math.toIntExact(Math.min(position + (long) maxPageSize, allSelected.getPositionCount()));
+            if (endPosition == position) {
+                return null;
             }
-            if (success == false && blocks != null) {
-                Releasables.closeExpectNoException(blocks);
+            final boolean singlePage = position == 0 && endPosition == allSelected.getPositionCount();
+            final Block[] blocks = new Block[allKeys.length + Arrays.stream(aggBlockCounts).sum()];
+            IntVector selected = null;
+            boolean success = false;
+            try {
+                if (singlePage) {
+                    this.allSelected.incRef();
+                    selected = this.allSelected;
+                    for (int i = 0; i < allKeys.length; i++) {
+                        allKeys[i].incRef();
+                        blocks[i] = allKeys[i];
+                    }
+                } else {
+                    final int[] positions = new int[endPosition - position];
+                    for (int i = 0; i < positions.length; i++) {
+                        positions[i] = position + i;
+                    }
+                    // TODO: allow filtering with an IntVector
+                    selected = allSelected.filter(positions);
+                    for (int keyIndex = 0; keyIndex < allKeys.length; keyIndex++) {
+                        blocks[keyIndex] = allKeys[keyIndex].filter(positions);
+                    }
                 }
+                int blockOffset = allKeys.length;
+                for (int i = 0; i < aggregators.size(); i++) {
+                    aggregators.get(i).evaluate(blocks, blockOffset, selected, driverContext);
+                    blockOffset += aggBlockCounts[i];
+                }
+                var output = new Page(blocks);
+                rowsEmitted += output.getPositionCount();
+                success = true;
+                return output;
+            } finally {
+                position = endPosition;
+                Releasables.close(selected, success ? null : Releasables.wrap(blocks));
             }
         }
+
+        @Override
+        public void close() {
+            Releasables.close(allSelected, allKeys != null ? Releasables.wrap(allKeys) : null);
+        }
+
+        boolean doneEmitting() {
+            return allSelected != null && position >= allSelected.getPositionCount();
+        }
+    }
+
+    @Override
+    public void finish() {
+        if (emitter == null) {
+            emitter = new Emitter(aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray());
+        }
+    }

     @Override
     public boolean isFinished() {
-        return finished && output == null;
+        return emitter != null && emitter.doneEmitting();
     }

     @Override
     public void close() {
-        if (output != null) {
-            output.releaseBlocks();
-        }
-        Releasables.close(blockHash, () -> Releasables.close(aggregators));
+        Releasables.close(emitter, blockHashClosed ? null : blockHash, () -> Releasables.close(aggregators));
     }

     @Override
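With the Emitter in place, finish() only arms the emitter, and each getOutput() call hands back one slice of at most maxPageSize positions until doneEmitting() flips isFinished() to true. A hedged sketch of the resulting call pattern from the caller's side (drain and sink are illustrative helpers, not part of this change):

// Sketch only: how a finished operator is drained page by page.
static void drain(Operator operator, java.util.function.Consumer<Page> sink) {
    operator.finish();                        // arms the Emitter; no output blocks are built yet
    Page page;
    while ((page = operator.getOutput()) != null) {
        try {
            sink.accept(page);                // each page holds at most maxPageSize positions
        } finally {
            page.releaseBlocks();             // the caller owns and releases each page
        }
    }
    assert operator.isFinished();             // the emitter has walked past the last position
    operator.close();                         // releases the Emitter (and the BlockHash if still open)
}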
@@ -205,16 +205,16 @@ public Page getOutput() {
             return null;
         }
         if (valuesAggregator != null) {
-            try {
-                return valuesAggregator.getOutput();
-            } finally {
-                final ValuesAggregator aggregator = this.valuesAggregator;
-                this.valuesAggregator = null;
-                Releasables.close(aggregator);
+            final Page output = valuesAggregator.getOutput();
+            if (output == null) {
+                Releasables.close(valuesAggregator, () -> this.valuesAggregator = null);
+            } else {
+                return output;
             }
         }
         if (ordinalAggregators.isEmpty() == false) {
             try {
+                // TODO: chunk output pages
                 return mergeOrdinalsSegmentResults();
             } catch (IOException e) {
                 throw new UncheckedIOException(e);
@@ -510,6 +510,7 @@ private static class ValuesAggregator implements Releasable {
                     maxPageSize,
                     false
                 ),
+                maxPageSize,
                 driverContext
             );
         }
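The getOutput() rewrite above adopts the same chunked contract for the delegate: the inner ValuesAggregator may now yield several pages, so the operator passes them through one call at a time and releases the delegate only once it returns null. A condensed sketch of that drain-then-close pattern (simplified; the field names follow the hunk, the surrounding ordinals branch is omitted):

// Inside getOutput(), assuming valuesAggregator can emit multiple pages.
if (valuesAggregator != null) {
    final Page output = valuesAggregator.getOutput();
    if (output != null) {
        return output;             // pass one chunk through per call
    }
    // The delegate is drained: close it and clear the field in one step.
    Releasables.close(valuesAggregator, () -> this.valuesAggregator = null);
}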
@@ -64,6 +64,7 @@ public Operator get(DriverContext driverContext) {
         return new HashAggregationOperator(
             aggregators,
             () -> new TimeSeriesBlockHash(tsHashChannel, timeBucketChannel, driverContext),
+            maxPageSize,
             driverContext
         );
     }
@@ -99,6 +100,7 @@ public Operator get(DriverContext driverContext) {
         return new HashAggregationOperator(
             aggregators,
             () -> BlockHash.build(hashGroups, driverContext.blockFactory(), maxPageSize, false),
+            maxPageSize,
             driverContext
         );
     }
@@ -127,6 +129,7 @@ public Operator get(DriverContext driverContext) {
         return new HashAggregationOperator(
             aggregators,
             () -> BlockHash.build(groupings, driverContext.blockFactory(), maxPageSize, false),
+            maxPageSize,
             driverContext
         );
     }
@@ -211,6 +211,7 @@ public String toString() {
                 randomPageSize(),
                 false
             ),
+            randomPageSize(),
             driverContext
         )
     );
@@ -42,6 +42,7 @@

 import java.util.ArrayList;
 import java.util.List;
+import java.util.Locale;
 import java.util.SortedSet;
 import java.util.TreeSet;
 import java.util.function.Function;
@@ -102,13 +103,25 @@ private Operator.OperatorFactory simpleWithMode(
         if (randomBoolean()) {
             supplier = chunkGroups(emitChunkSize, supplier);
         }
-        return new HashAggregationOperator.HashAggregationOperatorFactory(
+        final int maxPageSize = randomPageSize();
+        final var hashOperatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
             mode,
             List.of(supplier.groupingAggregatorFactory(mode, channels(mode))),
-            randomPageSize(),
+            maxPageSize,
             null
         );
+        return new Operator.OperatorFactory() {
+            @Override
+            public Operator get(DriverContext driverContext) {
+                return assertingOutputPageSize(hashOperatorFactory.get(driverContext), driverContext.blockFactory(), maxPageSize);
+            }
+
+            @Override
+            public String describe() {
+                return hashOperatorFactory.describe();
+            }
+        };
     }

     @Override
@@ -761,4 +774,79 @@ public String describe() {
         };
     }

+    static Operator assertingOutputPageSize(Operator operator, BlockFactory blockFactory, int maxPageSize) {
+        return new Operator() {
+            private final List<Page> pages = new ArrayList<>();
+
+            @Override
+            public boolean needsInput() {
+                return operator.needsInput();
+            }
+
+            @Override
+            public void addInput(Page page) {
+                operator.addInput(page);
+            }
+
+            @Override
+            public void finish() {
+                operator.finish();
+            }
+
+            @Override
+            public boolean isFinished() {
+                return operator.isFinished();
+            }
+
+            @Override
+            public Page getOutput() {
+                final Page page = operator.getOutput();
+                if (page != null && page.getPositionCount() > maxPageSize) {
+                    page.releaseBlocks();
+                    throw new AssertionError(
+                        String.format(
+                            Locale.ROOT,
+                            "Operator %s didn't chunk output pages properly; got an output page with %s positions, max_page_size=%s",
+                            operator,
+                            page.getPositionCount(),
+                            maxPageSize
+                        )
+                    );
+                }
+                if (page != null) {
+                    pages.add(page);
+                }
+                if (operator.isFinished()) {
+                    // TODO: Remove this workaround. We need to merge pages since many existing assertions expect a single output page.
+                    try {
+                        return BlockTestUtils.mergePages(blockFactory, pages);
+                    } finally {
+                        pages.forEach(Page::releaseBlocks);
+                        pages.clear();
+                    }
+                } else {
+                    return null;
+                }
+            }
+
+            @Override
+            public Status status() {
+                return operator.status();
+            }
+
+            @Override
+            public String toString() {
+                return operator.toString();
+            }
+
+            @Override
+            public void close() {
+                for (Page p : pages) {
+                    p.releaseBlocks();
+                }
+                operator.close();
+            }
+        };
+    }
+
 }