
Commit a309133

Move new block hashes to typical location

This makes them easier to test.

1 parent f2d1806 commit a309133

File tree

4 files changed: +325 -174 lines changed
AbstractCategorizeBlockHash.java
@@ -0,0 +1,70 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;

import java.io.IOException;

public abstract class AbstractCategorizeBlockHash extends BlockHash {
    private final boolean outputPartial;
    protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;

    AbstractCategorizeBlockHash(
        BlockFactory blockFactory,
        boolean outputPartial,
        TokenListCategorizer.CloseableTokenListCategorizer categorizer
    ) {
        super(blockFactory);
        this.outputPartial = outputPartial;
        this.categorizer = categorizer;
    }

    @Override
    public Block[] getKeys() {
        if (outputPartial) {
            // NOCOMMIT load partial
            Block state = null;
            Block keys; // NOCOMMIT do we even need to send the keys? it's just going to be 0 to the length of state
            // return new Block[] {new CompositeBlock()};
            return null;
        }

        // NOCOMMIT load final
        return new Block[0];
    }

    @Override
    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
        throw new UnsupportedOperationException();
    }

    // Serializes the categorizer state into a single constant BytesRef block:
    // a VInt category count followed by each SerializableTokenListCategory.
    private Block buildIntermediateBlock(BlockFactory blockFactory, int positionCount) {
        if (categorizer.getCategoryCount() == 0) {
            return blockFactory.newConstantNullBlock(positionCount);
        }
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            // TODO be more careful here.
            out.writeVInt(categorizer.getCategoryCount());
            for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) {
                category.writeTo(out);
            }
            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
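For context, buildIntermediateBlock writes a simple wire format: a VInt category count followed by each serialized category. Below is a minimal round-trip sketch of that framing using the same stream classes; plain strings stand in for SerializableTokenListCategory, so the payload type here is an assumption made for brevity.

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;

import java.io.IOException;
import java.util.List;

public class IntermediateFramingDemo {
    public static void main(String[] args) throws IOException {
        List<String> categories = List.of(".*error.*", ".*shutdown.*");

        // Write side: VInt count, then one entry per category
        // (strings stand in for SerializableTokenListCategory.writeTo).
        BytesRef bytes;
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            out.writeVInt(categories.size());
            for (String c : categories) {
                out.writeString(c);
            }
            bytes = out.bytes().toBytesRef();
        }

        // Read side mirrors readIntermediate in CategorizedIntermediateBlockHash.
        try (StreamInput in = new BytesArray(bytes).streamInput()) {
            int count = in.readVInt();
            for (int id = 0; id < count; id++) {
                System.out.println("category " + id + ": " + in.readString());
            }
        }
    }
}

readIntermediate in CategorizedIntermediateBlockHash below reads exactly this shape, substituting new SerializableTokenListCategory(in) for readString().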
CategorizeRawBlockHash.java
@@ -0,0 +1,160 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.Warnings;
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

import java.io.IOException;

public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
    private final CategorizeEvaluator evaluator;

    CategorizeRawBlockHash(
        BlockFactory blockFactory,
        boolean outputPartial,
        TokenListCategorizer.CloseableTokenListCategorizer categorizer,
        CategorizeEvaluator evaluator
    ) {
        super(blockFactory, outputPartial, categorizer);
        this.evaluator = evaluator;
    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        IntBlock result = (IntBlock) evaluator.eval(page);
        addInput.add(0, result);
    }

    @Override
    public IntVector nonEmpty() {
        // TODO
        return null;
    }

    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        // TODO
        return null;
    }

    @Override
    public void close() {
        // TODO
    }

    /**
     * NOCOMMIT: Super-duper copy-pasted.
     */
    public static final class CategorizeEvaluator implements EvalOperator.ExpressionEvaluator {
        private final Warnings warnings;

        private final EvalOperator.ExpressionEvaluator v;

        private final CategorizationAnalyzer analyzer;

        private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;

        private final DriverContext driverContext;

        static int process(
            BytesRef v,
            @Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
            @Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
        ) {
            String s = v.utf8ToString();
            try (TokenStream ts = analyzer.tokenStream("text", s)) {
                return categorizer.computeCategory(ts, s.length(), 1).getId();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        public CategorizeEvaluator(
            EvalOperator.ExpressionEvaluator v,
            CategorizationAnalyzer analyzer,
            TokenListCategorizer.CloseableTokenListCategorizer categorizer,
            DriverContext driverContext
        ) {
            this.v = v;
            this.analyzer = analyzer;
            this.categorizer = categorizer;
            this.driverContext = driverContext;
            this.warnings = Warnings.createWarnings(driverContext.warningsMode(), -1, -1, "");
        }

        @Override
        public Block eval(Page page) {
            try (BytesRefBlock vBlock = (BytesRefBlock) v.eval(page)) {
                BytesRefVector vVector = vBlock.asVector();
                if (vVector == null) {
                    return eval(page.getPositionCount(), vBlock);
                }
                return eval(page.getPositionCount(), vVector).asBlock();
            }
        }

        public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
            try (IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) {
                BytesRef vScratch = new BytesRef();
                position: for (int p = 0; p < positionCount; p++) {
                    if (vBlock.isNull(p)) {
                        result.appendNull();
                        continue position;
                    }
                    if (vBlock.getValueCount(p) != 1) {
                        if (vBlock.getValueCount(p) > 1) {
                            warnings.registerException(new IllegalArgumentException("single-value function encountered multi-value"));
                        }
                        result.appendNull();
                        continue position;
                    }
                    result.appendInt(process(vBlock.getBytesRef(vBlock.getFirstValueIndex(p), vScratch), this.analyzer, this.categorizer));
                }
                return result.build();
            }
        }

        public IntVector eval(int positionCount, BytesRefVector vVector) {
            try (IntVector.FixedBuilder result = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) {
                BytesRef vScratch = new BytesRef();
                position: for (int p = 0; p < positionCount; p++) {
                    result.appendInt(p, process(vVector.getBytesRef(p, vScratch), this.analyzer, this.categorizer));
                }
                return result.build();
            }
        }

        @Override
        public String toString() {
            return "CategorizeEvaluator[" + "v=" + v + "]";
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(v, analyzer, categorizer);
        }
    }
}
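The static process(...) above drives the categorizer through the standard Lucene TokenStream lifecycle. For readers unfamiliar with that contract, here is a self-contained sketch; StandardAnalyzer is an assumption standing in for the CategorizationAnalyzer the real code receives.

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

import java.io.IOException;

public class TokenStreamDemo {
    public static void main(String[] args) throws IOException {
        try (StandardAnalyzer analyzer = new StandardAnalyzer();
             TokenStream ts = analyzer.tokenStream("text", "Node stopped: out of memory")) {
            CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
            ts.reset();                   // mandatory before the first incrementToken()
            while (ts.incrementToken()) {
                System.out.println(term); // each token the categorizer would see
            }
            ts.end();                     // records final offset state
        }                                 // try-with-resources closes the stream
    }
}

The reset(), end(), and close() calls are required parts of the TokenStream contract; modern Lucene throws IllegalStateException if reset() is skipped.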
CategorizedIntermediateBlockHash.java
@@ -0,0 +1,95 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.CompositeBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
    private final IntBlockHash hash;
    private final int channel;

    CategorizedIntermediateBlockHash(
        BlockFactory blockFactory,
        boolean outputPartial,
        TokenListCategorizer.CloseableTokenListCategorizer categorizer,
        IntBlockHash hash,
        int channel
    ) {
        super(blockFactory, outputPartial, categorizer);
        this.hash = hash;
        this.channel = channel;
    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        CompositeBlock block = page.getBlock(channel);
        BytesRefBlock groupingState = block.getBlock(0);
        BytesRefBlock groups = block.getBlock(0);
        Map<Integer, Integer> idMap;
        if (groupingState.areAllValuesNull() == false) {
            idMap = readIntermediate(groupingState.getBytesRef(0, new BytesRef()));
        } else {
            idMap = Collections.emptyMap();
        }
        try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(groups.getTotalValueCount())) {
            for (int i = 0; i < groups.getTotalValueCount(); i++) {
                newIdsBuilder.appendInt(idMap.get(i));
            }
            IntBlock newIds = newIdsBuilder.build();
            addInput.add(0, hash.add(newIds));
        }
    }

    // Reads the VInt-framed category list written by buildIntermediateBlock,
    // merges each category into this hash's categorizer, and records the
    // mapping from the sender's category ids to the merged ids.
    private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
        Map<Integer, Integer> idMap = new HashMap<>();
        try (StreamInput in = new BytesArray(bytes).streamInput()) {
            int count = in.readVInt();
            for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
                SerializableTokenListCategory category = new SerializableTokenListCategory(in);
                int newCategoryId = categorizer.mergeWireCategory(category).getId();
                System.err.println("category id map: " + oldCategoryId + " -> " + newCategoryId + " (" + category.getRegex() + ")");
                idMap.put(oldCategoryId, newCategoryId);
            }
            return idMap;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public IntVector nonEmpty() {
        return hash.nonEmpty();
    }

    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        return hash.seenGroupIds(bigArrays);
    }

    @Override
    public void close() {

    }
}
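readIntermediate returns an old-to-new id map, which add(...) uses to rewrite the sender's category ids into this node's merged ids. The remap itself is ordinary map lookup; below is a minimal sketch with hypothetical ids (the map values and incoming group ids are invented for illustration).

import java.util.Arrays;
import java.util.Map;

public class IdRemapDemo {
    public static void main(String[] args) {
        // Hypothetical mapping from the sender's category ids to the ids
        // assigned by this node's categorizer after mergeWireCategory.
        Map<Integer, Integer> idMap = Map.of(0, 3, 1, 0, 2, 1);

        int[] incomingGroups = {0, 2, 1, 0};
        int[] remapped = new int[incomingGroups.length];
        for (int i = 0; i < incomingGroups.length; i++) {
            // Rewrite each old id through the map, as the IntBlock rebuild does.
            remapped[i] = idMap.get(incomingGroups[i]);
        }
        System.out.println(Arrays.toString(remapped)); // prints [3, 1, 0, 3]
    }
}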
