Skip to content

Commit 85b17b3

Browse files
phacops and claude committed
feat(eap): Add support for any aggregation function in EAP RPC
Add FUNCTION_ANY to the EAP RPC aggregation functions, enabling queries to return any non-null value from a group. This is useful for retrieving representative values when the specific value doesn't matter. Changes: - Add anyIfOrNull to both extrapolated and non-extrapolated function maps - Skip round() for FUNCTION_ANY to support non-numeric returns - Add proper type-based converters for FUNCTION_ANY in trace item table - Add tests for time series, extrapolation, and string attribute scenarios Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent af61d07 commit 85b17b3

File tree

5 files changed

+217
-2
lines changed

5 files changed

+217
-2
lines changed

snuba/web/rpc/v1/resolvers/common/aggregation.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,11 @@ def get_extrapolated_function(
513513
and_cond(get_field_existence_expression(field), condition_in_aggregation),
514514
**alias_dict,
515515
),
516+
Function.FUNCTION_ANY: f.anyIfOrNull(
517+
field,
518+
and_cond(get_field_existence_expression(field), condition_in_aggregation),
519+
**alias_dict,
520+
),
516521
}
517522

518523
return function_map_sample_weighted.get(aggregation.aggregate)
@@ -832,6 +837,10 @@ def aggregation_to_expression(
832837
field,
833838
and_cond(get_field_existence_expression(field), condition_in_aggregation),
834839
),
840+
Function.FUNCTION_ANY: f.anyIfOrNull(
841+
field,
842+
and_cond(get_field_existence_expression(field), condition_in_aggregation),
843+
),
835844
}
836845

837846
if aggregation.extrapolation_mode in [
@@ -845,7 +854,15 @@ def aggregation_to_expression(
845854
else:
846855
agg_func_expr = function_map.get(aggregation.aggregate)
847856
if agg_func_expr is not None:
848-
agg_func_expr = f.round(agg_func_expr, _FLOATING_POINT_PRECISION, **alias_dict)
857+
# Don't apply round() to FUNCTION_ANY since it can return non-numeric types (e.g., strings)
858+
if aggregation.aggregate == Function.FUNCTION_ANY:
859+
agg_func_expr = f.anyIfOrNull(
860+
field,
861+
and_cond(get_field_existence_expression(field), condition_in_aggregation),
862+
**alias_dict,
863+
)
864+
else:
865+
agg_func_expr = f.round(agg_func_expr, _FLOATING_POINT_PRECISION, **alias_dict)
849866

850867
if agg_func_expr is None:
851868
raise BadSnubaRPCRequestException(f"Aggregation not specified for {aggregation.key.name}")

snuba/web/rpc/v1/resolvers/common/trace_item_table.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sentry_protos.snuba.v1.trace_item_attribute_pb2 import (
1111
AttributeKey,
1212
AttributeValue,
13+
Function,
1314
Reliability,
1415
)
1516

@@ -33,8 +34,39 @@ def _add_converter(column: Column, converters: Dict[str, Callable[[Any], Attribu
3334
raise BadSnubaRPCRequestException(
3435
f"unknown attribute type: {AttributeKey.Type.Name(column.key.type)}"
3536
)
37+
elif column.HasField("aggregation"):
38+
# For FUNCTION_ANY, the result type matches the key type since it returns actual values
39+
if column.aggregation.aggregate == Function.FUNCTION_ANY:
40+
key_type = column.aggregation.key.type
41+
if key_type == AttributeKey.TYPE_STRING:
42+
converters[column.label] = lambda x: AttributeValue(val_str=str(x))
43+
elif key_type == AttributeKey.TYPE_INT:
44+
converters[column.label] = lambda x: AttributeValue(val_int=int(x))
45+
elif key_type == AttributeKey.TYPE_BOOLEAN:
46+
converters[column.label] = lambda x: AttributeValue(val_bool=bool(x))
47+
else:
48+
# Default to double for float/double types
49+
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
50+
else:
51+
# Other aggregation functions return numeric values
52+
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
3653
elif column.HasField("conditional_aggregation"):
37-
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
54+
# For FUNCTION_ANY, the result type matches the key type since it returns actual values
55+
# Note: AggregationToConditionalAggregationVisitor converts aggregation -> conditional_aggregation
56+
if column.conditional_aggregation.aggregate == Function.FUNCTION_ANY:
57+
key_type = column.conditional_aggregation.key.type
58+
if key_type == AttributeKey.TYPE_STRING:
59+
converters[column.label] = lambda x: AttributeValue(val_str=str(x))
60+
elif key_type == AttributeKey.TYPE_INT:
61+
converters[column.label] = lambda x: AttributeValue(val_int=int(x))
62+
elif key_type == AttributeKey.TYPE_BOOLEAN:
63+
converters[column.label] = lambda x: AttributeValue(val_bool=bool(x))
64+
else:
65+
# Default to double for float/double types
66+
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
67+
else:
68+
# Other aggregation functions return numeric values
69+
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
3870
elif column.HasField("formula"):
3971
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
4072
_add_converter(column.formula.left, converters)

tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,54 @@ def test_sum(self) -> None:
307307
),
308308
]
309309

310+
def test_any(self) -> None:
311+
# store a test metric with a value of 1, every second of one hour
312+
granularity_secs = 300
313+
query_duration = 60 * 30
314+
store_spans_timeseries(
315+
BASE_TIME,
316+
1,
317+
3600,
318+
metrics=[DummyMetric("test_metric", get_value=lambda x: 1)],
319+
)
320+
321+
message = TimeSeriesRequest(
322+
meta=RequestMeta(
323+
project_ids=[1, 2, 3],
324+
organization_id=1,
325+
cogs_category="something",
326+
referrer="something",
327+
start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
328+
end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)),
329+
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
330+
),
331+
aggregations=[
332+
AttributeAggregation(
333+
aggregate=Function.FUNCTION_ANY,
334+
key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
335+
label="any",
336+
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
337+
),
338+
],
339+
granularity_secs=granularity_secs,
340+
)
341+
response = EndpointTimeSeries().execute(message)
342+
expected_buckets = [
343+
Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
344+
for secs in range(0, query_duration, granularity_secs)
345+
]
346+
# any() returns any value from the group - since all values are 1, we expect 1
347+
assert response.result_timeseries == [
348+
TimeSeries(
349+
label="any",
350+
buckets=expected_buckets,
351+
data_points=[
352+
DataPoint(data=1, data_present=True, sample_count=300)
353+
for _ in range(len(expected_buckets))
354+
],
355+
),
356+
]
357+
310358
def test_with_group_by(self) -> None:
311359
store_spans_timeseries(
312360
BASE_TIME,

tests/web/rpc/v1/test_endpoint_time_series/test_endpoint_time_series_extrapolation.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,62 @@ def test_aggregations_reliable(self) -> None:
189189
),
190190
]
191191

192+
def test_any_extrapolated(self) -> None:
193+
# store a test metric with a value of 50, every second for an hour
194+
granularity_secs = 120
195+
query_duration = 3600
196+
store_timeseries(
197+
BASE_TIME,
198+
1,
199+
3600,
200+
metrics=[DummyMetric("test_metric", get_value=lambda x: 50)],
201+
server_sample_rate=1.0,
202+
)
203+
204+
message = TimeSeriesRequest(
205+
meta=RequestMeta(
206+
project_ids=[1, 2, 3],
207+
organization_id=1,
208+
cogs_category="something",
209+
referrer="something",
210+
start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
211+
end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)),
212+
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
213+
),
214+
aggregations=[
215+
AttributeAggregation(
216+
aggregate=Function.FUNCTION_ANY,
217+
key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
218+
label="any(test_metric)",
219+
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
220+
),
221+
],
222+
granularity_secs=granularity_secs,
223+
)
224+
response = EndpointTimeSeries().execute(message)
225+
expected_buckets = [
226+
Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
227+
for secs in range(0, query_duration, granularity_secs)
228+
]
229+
# any() returns any value from the group - since all values are 50, we expect 50
230+
# Note: any() doesn't have confidence intervals, so reliability is UNSPECIFIED
231+
assert sorted(response.result_timeseries, key=lambda x: x.label) == [
232+
TimeSeries(
233+
label="any(test_metric)",
234+
buckets=expected_buckets,
235+
data_points=[
236+
DataPoint(
237+
data=50,
238+
data_present=True,
239+
reliability=Reliability.RELIABILITY_UNSPECIFIED,
240+
avg_sampling_rate=1,
241+
sample_count=120,
242+
)
243+
for _ in range(len(expected_buckets))
244+
],
245+
),
246+
]
247+
192248
def test_confidence_interval_zero_estimate(self) -> None:
193249
# store a a test metric with a value of 1, every second for an hour
194250
granularity_secs = 120

tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,68 @@ def test_table_with_aggregates(self, setup_teardown: Any) -> None:
788788
),
789789
]
790790

791+
def test_any_aggregation_with_string_attribute(self, setup_teardown: Any) -> None:
792+
"""Test that any() aggregation works with string attributes.
793+
794+
The fixture creates 120 spans all with custom_tag="blah".
795+
Using any() on this attribute should return "blah" for each group.
796+
"""
797+
message = TraceItemTableRequest(
798+
meta=RequestMeta(
799+
project_ids=[1, 2, 3],
800+
organization_id=1,
801+
cogs_category="something",
802+
referrer="something",
803+
start_timestamp=START_TIMESTAMP,
804+
end_timestamp=END_TIMESTAMP,
805+
trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
806+
),
807+
filter=TraceItemFilter(
808+
exists_filter=ExistsFilter(
809+
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="custom_tag")
810+
)
811+
),
812+
columns=[
813+
Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="location")),
814+
Column(
815+
aggregation=AttributeAggregation(
816+
aggregate=Function.FUNCTION_ANY,
817+
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="custom_tag"),
818+
label="any(custom_tag)",
819+
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
820+
),
821+
),
822+
],
823+
group_by=[AttributeKey(type=AttributeKey.TYPE_STRING, name="location")],
824+
order_by=[
825+
TraceItemTableRequest.OrderBy(
826+
column=Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="location"))
827+
),
828+
],
829+
limit=5,
830+
)
831+
response = EndpointTraceItemTable().execute(message)
832+
833+
# All spans have custom_tag="blah", so any() should return "blah" for each location group
834+
assert response.column_values == [
835+
TraceItemColumnValues(
836+
attribute_name="location",
837+
results=[
838+
AttributeValue(val_str="backend"),
839+
AttributeValue(val_str="frontend"),
840+
AttributeValue(val_str="mobile"),
841+
],
842+
),
843+
TraceItemColumnValues(
844+
attribute_name="any(custom_tag)",
845+
results=[
846+
AttributeValue(val_str="blah"),
847+
AttributeValue(val_str="blah"),
848+
AttributeValue(val_str="blah"),
849+
],
850+
),
851+
]
852+
791853
def test_table_with_columns_not_in_groupby_backward_compat(self, setup_teardown: Any) -> None:
792854
message = TraceItemTableRequest(
793855
meta=RequestMeta(

0 commit comments

Comments (0)