Skip to content

Commit 640411a

Browse files
committed
✨ Add metrics aggregation for numeric types
1 parent b06bcb1 commit 640411a

File tree

5 files changed

+103
-1
lines changed

5 files changed

+103
-1
lines changed

openaleph_search/parse/parser.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
198198
# expand query with name synonyms (name_symbols and name_keys)
199199
self.synonyms = self.getbool("synonyms", False)
200200

201+
# metric aggregations (sum, avg, min, max) on numeric fields
202+
self.metrics = self.prefixed_items("metric:")
203+
201204
@cached_property
202205
def collection_ids(self) -> set[str]:
203206
collections = self.filters.get("collection_id", set())
@@ -334,4 +337,5 @@ def to_dict(self) -> dict[str, Any]:
334337
parser["synonyms"] = self.synonyms
335338
parser["include_fields"] = list(self.include_fields)
336339
parser["dehydrate"] = self.dehydrate
340+
parser["metrics"] = {key: list(val) for key, val in self.metrics.items()}
337341
return parser

openaleph_search/query/base.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ class Query:
4040
"score": "_score",
4141
}
4242
SORT_DEFAULT: ClassVar[list[str | dict[str, Any]]] = ["_score"]
43+
METRIC_TYPES: ClassVar[tuple[str, ...]] = ("sum", "avg", "min", "max")
4344
SOURCE: ClassVar[dict[str, Any]] = {}
4445

4546
def __init__(self, parser: SearchQueryParser) -> None:
@@ -279,6 +280,15 @@ def get_aggregations(self) -> dict[str, Any]:
279280
},
280281
}
281282

283+
# Metric aggregations (sum, avg, min, max) on numeric fields
284+
for metric_type, fields in self.parser.metrics.items():
285+
if metric_type not in self.METRIC_TYPES:
286+
continue
287+
for field in fields:
288+
es_field = f"{Field.NUMERIC}.{field}"
289+
agg_name = f"{field}.{metric_type}"
290+
aggregations[agg_name] = {metric_type: {"field": es_field}}
291+
282292
return aggregations
283293

284294
def get_significant_background(self) -> BoolQuery | None:

tests/test_search.py

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -617,6 +617,64 @@ def test_search_synonyms_name_keys(cleanup_after):
617617
assert result["hits"]["hits"][0]["_id"] == "darc-limited-company"
618618

619619

620+
def test_search_metric_aggregations(cleanup_after):
621+
"""Test metric aggregations (sum, avg, min, max) on numeric fields"""
622+
entities = [
623+
make_entity(
624+
{
625+
"id": "payment1",
626+
"schema": "Payment",
627+
"properties": {"amount": ["100"], "date": ["2024-01-01"]},
628+
}
629+
),
630+
make_entity(
631+
{
632+
"id": "payment2",
633+
"schema": "Payment",
634+
"properties": {"amount": ["250"], "date": ["2024-02-01"]},
635+
}
636+
),
637+
make_entity(
638+
{
639+
"id": "payment3",
640+
"schema": "Payment",
641+
"properties": {"amount": ["150"], "date": ["2024-03-01"]},
642+
}
643+
),
644+
]
645+
index_bulk("test_metrics", entities, sync=True)
646+
647+
# Test sum (filter:schemata=Interval to include the intervals index)
648+
query = _create_query(
649+
"/search?filter:dataset=test_metrics&filter:schemata=Interval"
650+
"&metric:sum=amount"
651+
)
652+
result = query.search()
653+
assert result["hits"]["total"]["value"] == 3
654+
assert result["aggregations"]["amount.sum"]["value"] == 500.0
655+
656+
# Test multiple metrics at once
657+
query = _create_query(
658+
"/search?filter:dataset=test_metrics&filter:schemata=Interval"
659+
"&metric:sum=amount&metric:avg=amount&metric:min=amount&metric:max=amount"
660+
)
661+
result = query.search()
662+
aggs = result["aggregations"]
663+
assert aggs["amount.sum"]["value"] == 500.0
664+
assert aggs["amount.avg"]["value"] == pytest.approx(500.0 / 3)
665+
assert aggs["amount.min"]["value"] == 100.0
666+
assert aggs["amount.max"]["value"] == 250.0
667+
668+
# Test with a filter narrowing results
669+
query = _create_query(
670+
"/search?filter:dataset=test_metrics&filter:schemata=Interval"
671+
"&filter:gte:properties.date=2024-02-01&metric:sum=amount"
672+
)
673+
result = query.search()
674+
assert result["hits"]["total"]["value"] == 2
675+
assert result["aggregations"]["amount.sum"]["value"] == 400.0
676+
677+
620678
def test_search_translation_plaintext(cleanup_after):
621679
"""Test that PlainText translatedText is searchable via ES copy_to into the
622680
translation field."""

tests/test_search_parser.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from unittest import TestCase
22

3-
from openaleph_search.parse.parser import QueryParser
3+
from openaleph_search.parse.parser import QueryParser, SearchQueryParser
44

55
args = QueryParser(
66
[
@@ -75,6 +75,15 @@ def test_to_dict(self):
7575
self.assertEqual(set(parser_dict["filters"]["key2"]), set(["foo3", "foo5"]))
7676
self.assertEqual(set(parser_dict["filters"]["key3"]), set(["foo4"]))
7777

78+
def test_metric_parsing(self):
79+
from werkzeug.datastructures import OrderedMultiDict
80+
81+
parser = SearchQueryParser(
82+
OrderedMultiDict([("metric:sum", "amount"), ("metric:sum", "salary")]),
83+
None,
84+
)
85+
assert parser.metrics == {"sum": {"amount", "salary"}}
86+
7887
def test_limit_zero(self):
7988
"""Test that limit=0 is preserved and not converted to default."""
8089
# Test with limit=0 in query args

tests/test_search_query.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,27 @@ def test_highlight_text(self):
176176
"bar",
177177
)
178178

179+
def test_metric_aggregations(self):
180+
q = query(
181+
[
182+
("metric:sum", "amount"),
183+
("metric:avg", "amount"),
184+
("metric:min", "registrationArea"),
185+
]
186+
)
187+
aggs = q.get_aggregations()
188+
self.assertEqual(aggs["amount.sum"], {"sum": {"field": "numeric.amount"}})
189+
self.assertEqual(aggs["amount.avg"], {"avg": {"field": "numeric.amount"}})
190+
self.assertEqual(
191+
aggs["registrationArea.min"],
192+
{"min": {"field": "numeric.registrationArea"}},
193+
)
194+
195+
def test_metric_invalid_type(self):
196+
q = query([("metric:percentile", "amount")])
197+
aggs = q.get_aggregations()
198+
self.assertNotIn("amount.percentile", aggs)
199+
179200
def test_schema_filter(self):
180201
q = query([("filter:schema", "Person")])
181202
assert q.get_filters() == [{"term": {"schema": "Person"}}]

0 commit comments

Comments
 (0)