19 changes: 18 additions & 1 deletion snuba/web/rpc/v1/resolvers/common/aggregation.py
@@ -513,6 +513,11 @@ def get_extrapolated_function(
             and_cond(get_field_existence_expression(field), condition_in_aggregation),
             **alias_dict,
         ),
+        Function.FUNCTION_ANY: f.anyIfOrNull(
+            field,
+            and_cond(get_field_existence_expression(field), condition_in_aggregation),
+            **alias_dict,
+        ),
     }

     return function_map_sample_weighted.get(aggregation.aggregate)
@@ -832,6 +837,10 @@ def aggregation_to_expression(
             field,
             and_cond(get_field_existence_expression(field), condition_in_aggregation),
         ),
+        Function.FUNCTION_ANY: f.anyIfOrNull(
+            field,
+            and_cond(get_field_existence_expression(field), condition_in_aggregation),
+        ),
     }

     if aggregation.extrapolation_mode in [
@@ -845,7 +854,15 @@
     else:
         agg_func_expr = function_map.get(aggregation.aggregate)
         if agg_func_expr is not None:
-            agg_func_expr = f.round(agg_func_expr, _FLOATING_POINT_PRECISION, **alias_dict)
+            # Don't apply round() to FUNCTION_ANY since it can return non-numeric types (e.g., strings)
+            if aggregation.aggregate == Function.FUNCTION_ANY:
+                agg_func_expr = f.anyIfOrNull(
+                    field,
+                    and_cond(get_field_existence_expression(field), condition_in_aggregation),
+                    **alias_dict,
+                )
+            else:
+                agg_func_expr = f.round(agg_func_expr, _FLOATING_POINT_PRECISION, **alias_dict)

     if agg_func_expr is None:
         raise BadSnubaRPCRequestException(f"Aggregation not specified for {aggregation.key.name}")
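Note on the rounding guard above: round() assumes a numeric argument, while any() can surface string or boolean attribute values. A minimal standalone sketch of the decision (plain Python, not the snuba AST; finalize_aggregate is a hypothetical name, and precision stands in for _FLOATING_POINT_PRECISION):

from typing import Any

def finalize_aggregate(aggregate: str, value: Any, precision: int) -> Any:
    # Hypothetical helper mirroring the new branch in aggregation_to_expression:
    # "any" results may be non-numeric, so they are passed through unrounded.
    if aggregate == "any":
        return value  # may be a str, bool, int, ...
    return round(value, precision)

assert finalize_aggregate("any", "blah", 7) == "blah"  # round("blah", 7) would raise TypeError
assert finalize_aggregate("avg", 1.23456789, 7) == 1.2345679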
34 changes: 33 additions & 1 deletion snuba/web/rpc/v1/resolvers/common/trace_item_table.py
@@ -10,6 +10,7 @@
 from sentry_protos.snuba.v1.trace_item_attribute_pb2 import (
     AttributeKey,
     AttributeValue,
+    Function,
     Reliability,
 )
@@ -33,8 +34,39 @@ def _add_converter(column: Column, converters: Dict[str, Callable[[Any], Attribu
         raise BadSnubaRPCRequestException(
             f"unknown attribute type: {AttributeKey.Type.Name(column.key.type)}"
         )
+    elif column.HasField("aggregation"):
+        # For FUNCTION_ANY, the result type matches the key type since it returns actual values
+        if column.aggregation.aggregate == Function.FUNCTION_ANY:
+            key_type = column.aggregation.key.type
+            if key_type == AttributeKey.TYPE_STRING:
+                converters[column.label] = lambda x: AttributeValue(val_str=str(x))
+            elif key_type == AttributeKey.TYPE_INT:
+                converters[column.label] = lambda x: AttributeValue(val_int=int(x))
+            elif key_type == AttributeKey.TYPE_BOOLEAN:
+                converters[column.label] = lambda x: AttributeValue(val_bool=bool(x))
+            else:
+                # Default to double for float/double types
+                converters[column.label] = lambda x: AttributeValue(val_double=float(x))
+        else:
+            # Other aggregation functions return numeric values
+            converters[column.label] = lambda x: AttributeValue(val_double=float(x))
     elif column.HasField("conditional_aggregation"):
-        converters[column.label] = lambda x: AttributeValue(val_double=float(x))
+        # For FUNCTION_ANY, the result type matches the key type since it returns actual values
+        # Note: AggregationToConditionalAggregationVisitor converts aggregation -> conditional_aggregation
+        if column.conditional_aggregation.aggregate == Function.FUNCTION_ANY:
+            key_type = column.conditional_aggregation.key.type
+            if key_type == AttributeKey.TYPE_STRING:
+                converters[column.label] = lambda x: AttributeValue(val_str=str(x))
+            elif key_type == AttributeKey.TYPE_INT:
+                converters[column.label] = lambda x: AttributeValue(val_int=int(x))
+            elif key_type == AttributeKey.TYPE_BOOLEAN:
+                converters[column.label] = lambda x: AttributeValue(val_bool=bool(x))
+            else:
+                # Default to double for float/double types
+                converters[column.label] = lambda x: AttributeValue(val_double=float(x))
+        else:
+            # Other aggregation functions return numeric values
+            converters[column.label] = lambda x: AttributeValue(val_double=float(x))
     elif column.HasField("formula"):
         converters[column.label] = lambda x: AttributeValue(val_double=float(x))
         _add_converter(column.formula.left, converters)
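The key-type dispatch above appears twice, once in the aggregation branch and once in the conditional_aggregation branch. A minimal sketch of a shared helper both branches could call (hypothetical, not part of the PR; assumes the AttributeKey/AttributeValue imports already present in this module):

from typing import Any, Callable

def _converter_for_any(key_type: int) -> Callable[[Any], AttributeValue]:
    # Hypothetical helper: map the attribute key type to the matching
    # AttributeValue field, defaulting to double for float/double types.
    if key_type == AttributeKey.TYPE_STRING:
        return lambda x: AttributeValue(val_str=str(x))
    if key_type == AttributeKey.TYPE_INT:
        return lambda x: AttributeValue(val_int=int(x))
    if key_type == AttributeKey.TYPE_BOOLEAN:
        return lambda x: AttributeValue(val_bool=bool(x))
    return lambda x: AttributeValue(val_double=float(x))

Each branch would then reduce to a single assignment, e.g. converters[column.label] = _converter_for_any(column.aggregation.key.type).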
@@ -307,6 +307,54 @@ def test_sum(self) -> None:
             ),
         ]

+    def test_any(self) -> None:
+        # store a test metric with a value of 1, every second for one hour
+        granularity_secs = 300
+        query_duration = 60 * 30
+        store_spans_timeseries(
+            BASE_TIME,
+            1,
+            3600,
+            metrics=[DummyMetric("test_metric", get_value=lambda x: 1)],
+        )
+
+        message = TimeSeriesRequest(
+            meta=RequestMeta(
+                project_ids=[1, 2, 3],
+                organization_id=1,
+                cogs_category="something",
+                referrer="something",
+                start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
+                end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)),
+                trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
+            ),
+            aggregations=[
+                AttributeAggregation(
+                    aggregate=Function.FUNCTION_ANY,
+                    key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
+                    label="any",
+                    extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
+                ),
+            ],
Comment on lines +310 to +338 (Member): this test is basically the same thing as the test below; only one is necessary
+            granularity_secs=granularity_secs,
+        )
+        response = EndpointTimeSeries().execute(message)
+        expected_buckets = [
+            Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
+            for secs in range(0, query_duration, granularity_secs)
+        ]
+        # any() returns any value from the group - since all values are 1, we expect 1
+        assert response.result_timeseries == [
+            TimeSeries(
+                label="any",
+                buckets=expected_buckets,
+                data_points=[
+                    DataPoint(data=1, data_present=True, sample_count=300)
+                    for _ in range(len(expected_buckets))
+                ],
+            ),
+        ]
+
     def test_with_group_by(self) -> None:
         store_spans_timeseries(
             BASE_TIME,
@@ -189,6 +189,62 @@ def test_aggregations_reliable(self) -> None:
             ),
         ]

+    def test_any_extrapolated(self) -> None:
+        # store a test metric with a value of 50, every second for an hour
+        granularity_secs = 120
+        query_duration = 3600
+        store_timeseries(
+            BASE_TIME,
+            1,
+            3600,
+            metrics=[DummyMetric("test_metric", get_value=lambda x: 50)],
+            server_sample_rate=1.0,
+        )
+
+        message = TimeSeriesRequest(
+            meta=RequestMeta(
+                project_ids=[1, 2, 3],
+                organization_id=1,
+                cogs_category="something",
+                referrer="something",
+                start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
+                end_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp() + query_duration)),
+                trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
+            ),
+            aggregations=[
+                AttributeAggregation(
+                    aggregate=Function.FUNCTION_ANY,
+                    key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"),
+                    label="any(test_metric)",
+                    extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED,
+                ),
+            ],
+            granularity_secs=granularity_secs,
+        )
+        response = EndpointTimeSeries().execute(message)
+        expected_buckets = [
+            Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
+            for secs in range(0, query_duration, granularity_secs)
+        ]
+        # any() returns any value from the group - since all values are 50, we expect 50
+        # Note: any() doesn't have confidence intervals, so reliability is UNSPECIFIED
+        assert sorted(response.result_timeseries, key=lambda x: x.label) == [
+            TimeSeries(
+                label="any(test_metric)",
+                buckets=expected_buckets,
+                data_points=[
+                    DataPoint(
+                        data=50,
+                        data_present=True,
+                        reliability=Reliability.RELIABILITY_UNSPECIFIED,
+                        avg_sampling_rate=1,
+                        sample_count=120,
+                    )
+                    for _ in range(len(expected_buckets))
+                ],
+            ),
+        ]
+
     def test_confidence_interval_zero_estimate(self) -> None:
         # store a test metric with a value of 1, every second for an hour
         granularity_secs = 120
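A quick check of the arithmetic behind the data points asserted in test_any_extrapolated: one span stored per second, 120-second buckets, and a 3600-second query window (plain Python, independent of the endpoint):

granularity_secs, query_duration = 120, 3600
assert query_duration // granularity_secs == 30  # number of expected buckets
assert granularity_secs * 1 == 120               # one span per second -> sample_count=120 per bucket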
@@ -788,6 +788,68 @@ def test_table_with_aggregates(self, setup_teardown: Any) -> None:
             ),
         ]

+    def test_any_aggregation_with_string_attribute(self, setup_teardown: Any) -> None:
+        """Test that any() aggregation works with string attributes.
+
+        The fixture creates 120 spans, all with custom_tag="blah". Using any()
+        on this attribute should return "blah" for each group; this is
+        deterministic only because every span carries the same tag value.
+        """
+        message = TraceItemTableRequest(
+            meta=RequestMeta(
+                project_ids=[1, 2, 3],
+                organization_id=1,
+                cogs_category="something",
+                referrer="something",
+                start_timestamp=START_TIMESTAMP,
+                end_timestamp=END_TIMESTAMP,
+                trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
+            ),
+            filter=TraceItemFilter(
+                exists_filter=ExistsFilter(
+                    key=AttributeKey(type=AttributeKey.TYPE_STRING, name="custom_tag")
+                )
+            ),
+            columns=[
+                Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="location")),
+                Column(
+                    aggregation=AttributeAggregation(
+                        aggregate=Function.FUNCTION_ANY,
+                        key=AttributeKey(type=AttributeKey.TYPE_STRING, name="custom_tag"),
+                        label="any(custom_tag)",
+                        extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
+                    ),
+                ),
+            ],
+            group_by=[AttributeKey(type=AttributeKey.TYPE_STRING, name="location")],
+            order_by=[
+                TraceItemTableRequest.OrderBy(
+                    column=Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="location"))
+                ),
+            ],
+            limit=5,
+        )
+        response = EndpointTraceItemTable().execute(message)
+
+        # All spans have custom_tag="blah", so any() should return "blah" for each location group
+        assert response.column_values == [
+            TraceItemColumnValues(
+                attribute_name="location",
+                results=[
+                    AttributeValue(val_str="backend"),
+                    AttributeValue(val_str="frontend"),
+                    AttributeValue(val_str="mobile"),
+                ],
+            ),
+            TraceItemColumnValues(
+                attribute_name="any(custom_tag)",
+                results=[
+                    AttributeValue(val_str="blah"),
+                    AttributeValue(val_str="blah"),
+                    AttributeValue(val_str="blah"),
+                ],
+            ),
+        ]
+
     def test_table_with_columns_not_in_groupby_backward_compat(self, setup_teardown: Any) -> None:
         message = TraceItemTableRequest(
             meta=RequestMeta(