Skip to content

Commit c011a9c

Browse files
authored
Integrate aggregators to convert result from datafusion (#19441)
Signed-off-by: expani <[email protected]>
1 parent 065c88d commit c011a9c

File tree

7 files changed

+136
-25
lines changed

7 files changed

+136
-25
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.search.aggregations;
10+
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
public interface ShardResultConvertor {
15+
16+
List<InternalAggregation> convert(Map<String, Object[]> shardResult);
17+
18+
}

server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.opensearch.search.aggregations.InternalAggregation;
5252
import org.opensearch.search.aggregations.LeafBucketCollector;
5353
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
54+
import org.opensearch.search.aggregations.ShardResultConvertor;
5455
import org.opensearch.search.aggregations.StarTreeBucketCollector;
5556
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
5657
import org.opensearch.search.aggregations.support.ValuesSource;
@@ -59,7 +60,9 @@
5960
import org.opensearch.search.startree.StarTreeQueryHelper;
6061

6162
import java.io.IOException;
63+
import java.util.ArrayList;
6264
import java.util.Arrays;
65+
import java.util.List;
6366
import java.util.Map;
6467
import java.util.concurrent.atomic.AtomicReference;
6568
import java.util.function.Function;
@@ -71,7 +74,7 @@
7174
*
7275
* @opensearch.internal
7376
*/
74-
class MaxAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector {
77+
class MaxAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector, ShardResultConvertor {
7578

7679
final ValuesSource.Numeric valuesSource;
7780
final DocValueFormat formatter;
@@ -280,4 +283,14 @@ public StarTreeBucketCollector getStarTreeBucketCollector(
280283
public void doReset() {
281284
maxes.fill(0, maxes.size(), Double.NEGATIVE_INFINITY);
282285
}
286+
287+
@Override
288+
public List<InternalAggregation> convert(Map<String, Object[]> shardResult) {
289+
Object[] values = shardResult.get(name);
290+
List<InternalAggregation> results = new ArrayList<>(values.length);
291+
for (Object value : values) {
292+
results.add(new InternalMax(name, (Long) value, formatter, metadata()));
293+
}
294+
return results;
295+
}
283296
}

server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.opensearch.search.aggregations.InternalAggregation;
5252
import org.opensearch.search.aggregations.LeafBucketCollector;
5353
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
54+
import org.opensearch.search.aggregations.ShardResultConvertor;
5455
import org.opensearch.search.aggregations.StarTreeBucketCollector;
5556
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
5657
import org.opensearch.search.aggregations.support.ValuesSource;
@@ -59,6 +60,8 @@
5960
import org.opensearch.search.startree.StarTreeQueryHelper;
6061

6162
import java.io.IOException;
63+
import java.util.ArrayList;
64+
import java.util.List;
6265
import java.util.Map;
6366
import java.util.concurrent.atomic.AtomicReference;
6467
import java.util.function.Function;
@@ -70,7 +73,7 @@
7073
*
7174
* @opensearch.internal
7275
*/
73-
class MinAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector {
76+
class MinAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector, ShardResultConvertor {
7477
private static final int MAX_BKD_LOOKUPS = 1024;
7578

7679
final ValuesSource.Numeric valuesSource;
@@ -271,4 +274,14 @@ public StarTreeBucketCollector getStarTreeBucketCollector(
271274
(bucket, metricValue) -> mins.set(bucket, Math.min(mins.get(bucket), NumericUtils.sortableLongToDouble(metricValue)))
272275
);
273276
}
277+
278+
@Override
279+
public List<InternalAggregation> convert(Map<String, Object[]> shardResult) {
280+
Object[] values = shardResult.get(name);
281+
List<InternalAggregation> results = new ArrayList<>(values.length);
282+
for (Object value : values) {
283+
results.add(new InternalMin(name, (Long) value, format, metadata()));
284+
}
285+
return results;
286+
}
274287
}

server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.opensearch.search.aggregations.InternalAggregation;
4646
import org.opensearch.search.aggregations.LeafBucketCollector;
4747
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
48+
import org.opensearch.search.aggregations.ShardResultConvertor;
4849
import org.opensearch.search.aggregations.StarTreeBucketCollector;
4950
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
5051
import org.opensearch.search.aggregations.support.ValuesSource;
@@ -53,6 +54,8 @@
5354
import org.opensearch.search.startree.StarTreeQueryHelper;
5455

5556
import java.io.IOException;
57+
import java.util.ArrayList;
58+
import java.util.List;
5659
import java.util.Map;
5760

5861
import static org.opensearch.search.startree.StarTreeQueryHelper.getSupportedStarTree;
@@ -62,7 +65,7 @@
6265
*
6366
* @opensearch.internal
6467
*/
65-
public class SumAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector {
68+
public class SumAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector, ShardResultConvertor {
6669

6770
private final ValuesSource.Numeric valuesSource;
6871
private final DocValueFormat format;
@@ -215,4 +218,14 @@ public InternalAggregation buildEmptyAggregation() {
215218
public void doClose() {
216219
Releasables.close(sums, compensations);
217220
}
221+
222+
@Override
223+
public List<InternalAggregation> convert(Map<String, Object[]> shardResult) {
224+
Object[] values = shardResult.get(name);
225+
List<InternalAggregation> results = new ArrayList<>(values.length);
226+
for (Object value : values) {
227+
results.add(new InternalSum(name, (Long) value, format, metadata()));
228+
}
229+
return results;
230+
}
218231
}

server/src/main/java/org/opensearch/search/aggregations/metrics/ValueCountAggregator.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.opensearch.search.aggregations.InternalAggregation;
4646
import org.opensearch.search.aggregations.LeafBucketCollector;
4747
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
48+
import org.opensearch.search.aggregations.ShardResultConvertor;
4849
import org.opensearch.search.aggregations.StarTreeBucketCollector;
4950
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
5051
import org.opensearch.search.aggregations.support.ValuesSource;
@@ -53,6 +54,8 @@
5354
import org.opensearch.search.startree.StarTreeQueryHelper;
5455

5556
import java.io.IOException;
57+
import java.util.ArrayList;
58+
import java.util.List;
5659
import java.util.Map;
5760

5861
import static org.opensearch.search.startree.StarTreeQueryHelper.getSupportedStarTree;
@@ -65,7 +68,7 @@
6568
*
6669
* @opensearch.internal
6770
*/
68-
public class ValueCountAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector {
71+
public class ValueCountAggregator extends NumericMetricsAggregator.SingleValue implements StarTreePreComputeCollector, ShardResultConvertor {
6972

7073
final ValuesSource valuesSource;
7174

@@ -209,4 +212,14 @@ public StarTreeBucketCollector getStarTreeBucketCollector(
209212
(bucket, metricValue) -> counts.increment(bucket, metricValue)
210213
);
211214
}
215+
216+
@Override
217+
public List<InternalAggregation> convert(Map<String, Object[]> shardResult) {
218+
Object[] values = shardResult.get(name);
219+
List<InternalAggregation> results = new ArrayList<>(values.length);
220+
for (Object value : values) {
221+
results.add(new InternalValueCount(name, (Long) value, metadata()));
222+
}
223+
return results;
224+
}
212225
}

server/src/main/java/org/opensearch/search/query/QueryPhase.java

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@
6060
import org.opensearch.search.aggregations.AggregationProcessor;
6161
import org.opensearch.search.aggregations.DefaultAggregationProcessor;
6262
import org.opensearch.search.aggregations.GlobalAggCollectorManager;
63-
import org.opensearch.search.aggregations.InternalAggregation;
6463
import org.opensearch.search.aggregations.InternalAggregations;
65-
import org.opensearch.search.aggregations.metrics.InternalValueCount;
6664
import org.opensearch.search.internal.ContextIndexSearcher;
6765
import org.opensearch.search.internal.ScrollContext;
6866
import org.opensearch.search.internal.SearchContext;
@@ -75,7 +73,6 @@
7573
import org.opensearch.threadpool.ThreadPool;
7674

7775
import java.io.IOException;
78-
import java.util.ArrayList;
7976
import java.util.LinkedList;
8077
import java.util.List;
8178
import java.util.Map;
@@ -168,28 +165,16 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep
168165

169166
// boolean rescore = executeInternal(searchContext, queryPhaseSearcher);
170167

168+
// Post process
169+
final InternalAggregations internalAggregations = SearchEngineResultConversionUtils.convertDFResultGeneric(searchContext);
170+
LOGGER.info("InternalAggregation created is {}", internalAggregations.asList());
171+
searchContext.queryResult().aggregations(internalAggregations);
172+
171173
// if (rescore) { // only if we do a regular search
172174
// rescoreProcessor.process(searchContext);
173175
// }
174176
// suggestProcessor.process(searchContext);
175-
// aggregationProcessor.postProcess(searchContext);
176-
177-
// Post process
178-
// Create a list to store the InternalValueCount objects
179-
// Can we map from the preprocess
180-
List<InternalAggregation> internalAggList = new ArrayList<>();
181-
Map<String, Object[]> map = searchContext.getDFResults();
182-
for (Map.Entry<String, Object[]> entry : map.entrySet()) {
183-
String key = entry.getKey();
184-
Object[] value = entry.getValue();
185-
// SUM, Count will work with integer casting, but (Integer) value casting may not work well for avg
186-
InternalValueCount ivc = new InternalValueCount(key, (long) value[0], null);
187-
internalAggList.add(ivc);
188-
}
189-
190-
final InternalAggregations internalAggregations = InternalAggregations.from(internalAggList);
191-
QuerySearchResult querySearchResult = searchContext.queryResult();
192-
querySearchResult.aggregations(internalAggregations);
177+
aggregationProcessor.postProcess(searchContext);
193178

194179
if (searchContext.getProfilers() != null) {
195180
ProfileShardResult shardResults = SearchProfileShardResults.buildShardResults(
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.search.query;
10+
11+
import org.opensearch.search.aggregations.Aggregator;
12+
import org.opensearch.search.aggregations.InternalAggregations;
13+
import org.opensearch.search.aggregations.ShardResultConvertor;
14+
import org.opensearch.search.internal.SearchContext;
15+
16+
import java.io.IOException;
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
import java.util.Map;
20+
import java.util.stream.Collectors;
21+
22+
public class SearchEngineResultConversionUtils {
23+
24+
public static InternalAggregations convertDFResultGeneric(SearchContext searchContext) {
25+
Map<String, Object[]> dfResult = searchContext.getDFResults();
26+
27+
// Create aggregators which will process the result from DataFusion
28+
try {
29+
30+
List<Aggregator> aggregators = new ArrayList<>();
31+
32+
if (searchContext.aggregations().factories().hasGlobalAggregator()) {
33+
aggregators.addAll(searchContext.aggregations().factories().createTopLevelGlobalAggregators(searchContext));
34+
}
35+
36+
if (searchContext.aggregations().factories().hasNonGlobalAggregator()) {
37+
aggregators.addAll(searchContext.aggregations().factories().createTopLevelNonGlobalAggregators(searchContext));
38+
}
39+
40+
List<ShardResultConvertor> shardResultConvertors = aggregators.stream().map(x -> {
41+
if (x instanceof ShardResultConvertor) {
42+
return ((ShardResultConvertor) x);
43+
} else {
44+
throw new UnsupportedOperationException("Aggregator doesn't support converting results from shard: " + x);
45+
}
46+
}).toList();
47+
48+
return InternalAggregations.from(
49+
shardResultConvertors.stream().flatMap(x -> x.convert(dfResult).stream()).collect(Collectors.toList())
50+
);
51+
} catch (IOException e) {
52+
throw new RuntimeException(e);
53+
}
54+
}
55+
56+
}

0 commit comments

Comments
 (0)