Skip to content

Commit 91a5895

Browse files
phacops and claude
authored
feat(eap): Add support for any aggregation function in EAP RPC (#7660)
## Summary - Add `FUNCTION_ANY` to the EAP RPC aggregation functions - Support both non-extrapolated and extrapolated queries - Support string, int, and boolean attribute types for `any()` results - Add comprehensive tests ## Details The `any()` aggregation function returns any non-null value from a group. This is useful for retrieving representative values when the specific value doesn't matter. Implementation: - Added `anyIfOrNull` to `aggregation_to_expression()` for non-extrapolated queries - Added `anyIfOrNull` to `get_extrapolated_function()` for extrapolated queries - Skip `round()` wrapper for `FUNCTION_ANY` since it can return non-numeric types - Added type-based converters for `FUNCTION_ANY` in trace item table to properly handle string/int/boolean results **Note:** This requires `sentry-protos>=0.4.14` which includes `FUNCTION_ANY` (value 13). ## Test plan - [x] Added `test_any` for basic time series aggregation without extrapolation - [x] Added `test_any_extrapolated` for time series aggregation with sample-weighted extrapolation - [x] Added `test_any_aggregation_with_string_attribute` - inserts many spans with the same string attribute value (`custom_tag="blah"`) and verifies `any()` returns the string correctly in trace item table - [x] Verified existing aggregation tests still pass 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent 82e3a15 commit 91a5895

File tree

5 files changed

+228
-17
lines changed

5 files changed

+228
-17
lines changed

snuba/web/rpc/v1/resolvers/common/aggregation.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,11 @@ def get_extrapolated_function(
513513
and_cond(get_field_existence_expression(field), condition_in_aggregation),
514514
**alias_dict,
515515
),
516+
Function.FUNCTION_ANY: f.anyIfOrNull(
517+
field,
518+
and_cond(get_field_existence_expression(field), condition_in_aggregation),
519+
**alias_dict,
520+
),
516521
}
517522

518523
return function_map_sample_weighted.get(aggregation.aggregate)
@@ -832,6 +837,10 @@ def aggregation_to_expression(
832837
field,
833838
and_cond(get_field_existence_expression(field), condition_in_aggregation),
834839
),
840+
Function.FUNCTION_ANY: f.anyIfOrNull(
841+
field,
842+
and_cond(get_field_existence_expression(field), condition_in_aggregation),
843+
),
835844
}
836845

837846
if aggregation.extrapolation_mode in [
@@ -845,7 +854,15 @@ def aggregation_to_expression(
845854
else:
846855
agg_func_expr = function_map.get(aggregation.aggregate)
847856
if agg_func_expr is not None:
848-
agg_func_expr = f.round(agg_func_expr, _FLOATING_POINT_PRECISION, **alias_dict)
857+
# Don't apply round() to FUNCTION_ANY since it can return non-numeric types (e.g., strings)
858+
if aggregation.aggregate == Function.FUNCTION_ANY:
859+
agg_func_expr = f.anyIfOrNull(
860+
field,
861+
and_cond(get_field_existence_expression(field), condition_in_aggregation),
862+
**alias_dict,
863+
)
864+
else:
865+
agg_func_expr = f.round(agg_func_expr, _FLOATING_POINT_PRECISION, **alias_dict)
849866

850867
if agg_func_expr is None:
851868
raise BadSnubaRPCRequestException(f"Aggregation not specified for {aggregation.key.name}")

snuba/web/rpc/v1/resolvers/common/trace_item_table.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,65 @@
1010
from sentry_protos.snuba.v1.trace_item_attribute_pb2 import (
1111
AttributeKey,
1212
AttributeValue,
13+
Function,
1314
Reliability,
1415
)
1516

1617
from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException
1718
from snuba.web.rpc.v1.resolvers.common.aggregation import ExtrapolationContext
1819

1920

21+
def _get_converter_for_type(
22+
key_type: "AttributeKey.Type.ValueType",
23+
) -> Callable[[Any], AttributeValue]:
24+
"""Returns a converter function for the given attribute type."""
25+
if key_type == AttributeKey.TYPE_BOOLEAN:
26+
return lambda x: AttributeValue(val_bool=bool(x))
27+
elif key_type == AttributeKey.TYPE_STRING:
28+
return lambda x: AttributeValue(val_str=str(x))
29+
elif key_type == AttributeKey.TYPE_INT:
30+
return lambda x: AttributeValue(val_int=int(x))
31+
elif key_type == AttributeKey.TYPE_FLOAT:
32+
return lambda x: AttributeValue(val_float=float(x))
33+
elif key_type == AttributeKey.TYPE_DOUBLE:
34+
return lambda x: AttributeValue(val_double=float(x))
35+
else:
36+
raise BadSnubaRPCRequestException(
37+
f"unknown attribute type: {AttributeKey.Type.Name(key_type)}"
38+
)
39+
40+
41+
def _get_double_converter() -> Callable[[Any], AttributeValue]:
42+
"""Returns a converter that converts to double (used for most aggregations)."""
43+
return lambda x: AttributeValue(val_double=float(x))
44+
45+
2046
def _add_converter(column: Column, converters: Dict[str, Callable[[Any], AttributeValue]]) -> None:
2147
if column.HasField("key"):
22-
if column.key.type == AttributeKey.TYPE_BOOLEAN:
23-
converters[column.label] = lambda x: AttributeValue(val_bool=bool(x))
24-
elif column.key.type == AttributeKey.TYPE_STRING:
25-
converters[column.label] = lambda x: AttributeValue(val_str=str(x))
26-
elif column.key.type == AttributeKey.TYPE_INT:
27-
converters[column.label] = lambda x: AttributeValue(val_int=int(x))
28-
elif column.key.type == AttributeKey.TYPE_FLOAT:
29-
converters[column.label] = lambda x: AttributeValue(val_float=float(x))
30-
elif column.key.type == AttributeKey.TYPE_DOUBLE:
31-
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
48+
converters[column.label] = _get_converter_for_type(column.key.type)
49+
elif column.HasField("aggregation"):
50+
# For FUNCTION_ANY, the result type matches the key type since it returns actual values
51+
if column.aggregation.aggregate == Function.FUNCTION_ANY:
52+
converters[column.label] = _get_converter_for_type(column.aggregation.key.type)
3253
else:
33-
raise BadSnubaRPCRequestException(
34-
f"unknown attribute type: {AttributeKey.Type.Name(column.key.type)}"
35-
)
54+
# Other aggregation functions return numeric values
55+
converters[column.label] = _get_double_converter()
3656
elif column.HasField("conditional_aggregation"):
37-
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
57+
# For FUNCTION_ANY, the result type matches the key type since it returns actual values
58+
# Note: AggregationToConditionalAggregationVisitor converts aggregation -> conditional_aggregation
59+
if column.conditional_aggregation.aggregate == Function.FUNCTION_ANY:
60+
converters[column.label] = _get_converter_for_type(
61+
column.conditional_aggregation.key.type
62+
)
63+
else:
64+
# Other aggregation functions return numeric values
65+
converters[column.label] = _get_double_converter()
3866
elif column.HasField("formula"):
39-
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
67+
converters[column.label] = _get_double_converter()
4068
_add_converter(column.formula.left, converters)
4169
_add_converter(column.formula.right, converters)
4270
elif column.HasField("literal"):
43-
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
71+
converters[column.label] = _get_double_converter()
4472
else:
4573
raise BadSnubaRPCRequestException(
4674
"column is not one of: attribute, (conditional) aggregation, or formula"

tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,54 @@ def test_sum(self) -> None:
307307
),
308308
]
309309

310+
def test_any(self) -> None:
311+
# store a test metric with a value of 1, every second of one hour
312+
granularity_secs = 300
313+
query_duration = 60 * 30
314+
store_spans_timeseries(
315+
BASE_TIME,
316+
1,
317+
3600,
318+
metrics=[DummyMetric("test_metric", get_value=lambda x: 1)],
319+
)
320+
321+
message = TimeSeriesRequest(
322+
meta=RequestMeta(
323+
project_ids=[1, 2, 3],
324+
organization_id=1,
325+
cogs_category="something",
326+
referrer="something",
327+
start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
328+
end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)),
329+
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
330+
),
331+
aggregations=[
332+
AttributeAggregation(
333+
aggregate=Function.FUNCTION_ANY,
334+
key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
335+
label="any",
336+
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
337+
),
338+
],
339+
granularity_secs=granularity_secs,
340+
)
341+
response = EndpointTimeSeries().execute(message)
342+
expected_buckets = [
343+
Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
344+
for secs in range(0, query_duration, granularity_secs)
345+
]
346+
# any() returns any value from the group - since all values are 1, we expect 1
347+
assert response.result_timeseries == [
348+
TimeSeries(
349+
label="any",
350+
buckets=expected_buckets,
351+
data_points=[
352+
DataPoint(data=1, data_present=True, sample_count=300)
353+
for _ in range(len(expected_buckets))
354+
],
355+
),
356+
]
357+
310358
def test_with_group_by(self) -> None:
311359
store_spans_timeseries(
312360
BASE_TIME,

tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_extrapolation.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,62 @@ def test_aggregations_reliable(self) -> None:
189189
),
190190
]
191191

192+
def test_any_extrapolated(self) -> None:
193+
# store a test metric with a value of 50, every second for an hour
194+
granularity_secs = 120
195+
query_duration = 3600
196+
store_timeseries(
197+
BASE_TIME,
198+
1,
199+
3600,
200+
metrics=[DummyMetric("test_metric", get_value=lambda x: 50)],
201+
server_sample_rate=1.0,
202+
)
203+
204+
message = TimeSeriesRequest(
205+
meta=RequestMeta(
206+
project_ids=[1, 2, 3],
207+
organization_id=1,
208+
cogs_category="something",
209+
referrer="something",
210+
start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
211+
end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)),
212+
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
213+
),
214+
aggregations=[
215+
AttributeAggregation(
216+
aggregate=Function.FUNCTION_ANY,
217+
key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
218+
label="any(test_metric)",
219+
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
220+
),
221+
],
222+
granularity_secs=granularity_secs,
223+
)
224+
response = EndpointTimeSeries().execute(message)
225+
expected_buckets = [
226+
Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
227+
for secs in range(0, query_duration, granularity_secs)
228+
]
229+
# any() returns any value from the group - since all values are 50, we expect 50
230+
# Note: any() doesn't have confidence intervals, so reliability is UNSPECIFIED
231+
assert sorted(response.result_timeseries, key=lambda x: x.label) == [
232+
TimeSeries(
233+
label="any(test_metric)",
234+
buckets=expected_buckets,
235+
data_points=[
236+
DataPoint(
237+
data=50,
238+
data_present=True,
239+
reliability=Reliability.RELIABILITY_UNSPECIFIED,
240+
avg_sampling_rate=1,
241+
sample_count=120,
242+
)
243+
for _ in range(len(expected_buckets))
244+
],
245+
),
246+
]
247+
192248
def test_confidence_interval_zero_estimate(self) -> None:
193249
# store a a test metric with a value of 1, every second for an hour
194250
granularity_secs = 120

tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,68 @@ def test_table_with_aggregates(self, setup_teardown: Any) -> None:
788788
),
789789
]
790790

791+
def test_any_aggregation_with_string_attribute(self, setup_teardown: Any) -> None:
792+
"""Test that any() aggregation works with string attributes.
793+
794+
The fixture creates 120 spans all with custom_tag="blah".
795+
Using any() on this attribute should return "blah" for each group.
796+
"""
797+
message = TraceItemTableRequest(
798+
meta=RequestMeta(
799+
project_ids=[1, 2, 3],
800+
organization_id=1,
801+
cogs_category="something",
802+
referrer="something",
803+
start_timestamp=START_TIMESTAMP,
804+
end_timestamp=END_TIMESTAMP,
805+
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
806+
),
807+
filter=TraceItemFilter(
808+
exists_filter=ExistsFilter(
809+
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="custom_tag")
810+
)
811+
),
812+
columns=[
813+
Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="location")),
814+
Column(
815+
aggregation=AttributeAggregation(
816+
aggregate=Function.FUNCTION_ANY,
817+
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="custom_tag"),
818+
label="any(custom_tag)",
819+
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
820+
),
821+
),
822+
],
823+
group_by=[AttributeKey(type=AttributeKey.TYPE_STRING, name="location")],
824+
order_by=[
825+
TraceItemTableRequest.OrderBy(
826+
column=Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="location"))
827+
),
828+
],
829+
limit=5,
830+
)
831+
response = EndpointTraceItemTable().execute(message)
832+
833+
# All spans have custom_tag="blah", so any() should return "blah" for each location group
834+
assert response.column_values == [
835+
TraceItemColumnValues(
836+
attribute_name="location",
837+
results=[
838+
AttributeValue(val_str="backend"),
839+
AttributeValue(val_str="frontend"),
840+
AttributeValue(val_str="mobile"),
841+
],
842+
),
843+
TraceItemColumnValues(
844+
attribute_name="any(custom_tag)",
845+
results=[
846+
AttributeValue(val_str="blah"),
847+
AttributeValue(val_str="blah"),
848+
AttributeValue(val_str="blah"),
849+
],
850+
),
851+
]
852+
791853
def test_table_with_columns_not_in_groupby_backward_compat(self, setup_teardown: Any) -> None:
792854
message = TraceItemTableRequest(
793855
meta=RequestMeta(

0 commit comments

Comments
 (0)