20 changes: 15 additions & 5 deletions api/py/ai/chronon/group_by.py
@@ -257,6 +257,14 @@ def validate_group_by(group_by: ttypes.GroupBy):
Keys {unselected_keys}, are unselected in source
"""

# For global aggregations (empty keys), aggregations must be specified
if not keys:
assert aggregations is not None and len(aggregations) > 0, (
"Global aggregations (empty keys) require at least one aggregation to be specified. "
"To compute global aggregates, provide aggregations like "
"[Aggregation(input_column='col', operation=Operation.SUM)]."
)

# Aggregations=None is only valid if group_by is Entities
if aggregations is None:
is_events = any([s.events for s in sources])
@@ -359,7 +367,7 @@ def get_output_col_names(aggregation):

def GroupBy(
sources: Union[List[_ANY_SOURCE_TYPE], _ANY_SOURCE_TYPE],
keys: List[str],
keys: Optional[List[str]],
aggregations: Optional[List[ttypes.Aggregation]],
online: Optional[bool] = DEFAULT_ONLINE,
production: Optional[bool] = DEFAULT_PRODUCTION,
@@ -408,8 +416,9 @@ def GroupBy(
:type sources: List[ai.chronon.api.ttypes.Events|ai.chronon.api.ttypes.Entities]
:param keys:
List of primary keys that defines the data that needs to be collected in the result table. Similar to the
GroupBy in the SQL context.
:type keys: List[String]
GroupBy in the SQL context. For global aggregations (computing a single aggregate value across all data),
pass either None or an empty list. In this case, aggregations will be computed without grouping by any keys.
:type keys: Optional[List[String]]
:param aggregations:
List of aggregations that needs to be computed for the data following the grouping defined by the keys::

@@ -500,11 +509,12 @@ def GroupBy(
"""
assert sources, "Sources are not specified"

key_columns = keys or []
agg_inputs = []
if aggregations is not None:
agg_inputs = [agg.inputColumn for agg in aggregations]

required_columns = keys + agg_inputs
required_columns = key_columns + agg_inputs

def _sanitize_columns(source: ttypes.Source):
query = (
@@ -577,7 +587,7 @@ def _normalize_source(source):

group_by = ttypes.GroupBy(
sources=sources,
keyColumns=keys,
keyColumns=key_columns,
aggregations=aggregations,
metaData=metadata,
backfillStartDate=backfill_start_date,
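For context, a minimal sketch of the usage this change enables (illustrative only, not part of the diff; the table and column names are hypothetical). A GroupBy with keys=None or keys=[] computes a single set of aggregates over all rows of its source:

# Illustrative sketch only -- not part of this diff; "orders", "order_id",
# and "amount" are hypothetical names.
from ai.chronon import group_by, query
from ai.chronon.api import ttypes

# Global aggregation: no grouping keys, so one COUNT and one SUM are
# computed over every row of the source.
global_order_stats = group_by.GroupBy(
    sources=ttypes.EventSource(
        table="orders",
        query=query.Query(
            selects={"order_id": "order_id", "amount": "amount"},
            time_column="ts",
        ),
    ),
    keys=None,  # same as keys=[]; normalized internally via key_columns = keys or []
    aggregations=[
        group_by.Aggregation(input_column="order_id", operation=ttypes.Operation.COUNT),
        group_by.Aggregation(input_column="amount", operation=ttypes.Operation.SUM),
    ],
)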
159 changes: 78 additions & 81 deletions api/py/test/test_group_by.py
@@ -1,4 +1,3 @@

# Copyright (C) 2023 The Chronon Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest, json
import json

import pytest
from ai.chronon import group_by, query
from ai.chronon.group_by import GroupBy, Derivation, TimeUnit, Window, Aggregation, Accuracy
from ai.chronon.api import ttypes
from ai.chronon.api.ttypes import EventSource, EntitySource, Operation
from ai.chronon.group_by import Accuracy, Derivation


@pytest.fixture
@@ -50,11 +49,7 @@ def event_source(table, topic=None):
topic=topic,
query=ttypes.Query(
startPartition="2020-04-09",
selects={
"subject": "subject_sql",
"event_id": "event_sql",
"cnt": 1
},
selects={"subject": "subject_sql", "event_id": "event_sql", "cnt": 1},
timeColumn="CAST(ts AS DOUBLE)",
),
)
@@ -69,30 +64,21 @@ def entity_source(snapshotTable, mutationTable):
mutationTable=mutationTable,
query=ttypes.Query(
startPartition="2020-04-09",
selects={
"subject": "subject_sql",
"event_id": "event_sql",
"cnt": 1
},
selects={"subject": "subject_sql", "event_id": "event_sql", "cnt": 1},
timeColumn="CAST(ts AS DOUBLE)",
mutationTimeColumn="__mutationTs",
reversalColumn="is_reverse",
),
)


def test_pretty_window_str(days_unit, hours_unit):
"""
Test pretty window utils.
"""
window = ttypes.Window(
length=7,
timeUnit=days_unit
)
window = ttypes.Window(length=7, timeUnit=days_unit)
assert group_by.window_to_str_pretty(window) == "7 days"
window = ttypes.Window(
length=2,
timeUnit=hours_unit
)
window = ttypes.Window(length=2, timeUnit=hours_unit)
assert group_by.window_to_str_pretty(window) == "2 hours"


@@ -108,7 +94,7 @@ def test_select():
"""
Test select builder
"""
assert query.select('subject', event="event_expr") == {"subject": "subject", "event": "event_expr"}
assert query.select("subject", event="event_expr") == {"subject": "subject", "event": "event_expr"}


def test_contains_windowed_aggregation(sum_op, min_op, days_unit):
@@ -117,16 +103,12 @@ def test_contains_windowed_aggregation(sum_op, min_op, days_unit):
"""
assert not group_by.contains_windowed_aggregation([])
aggregations = [
ttypes.Aggregation(inputColumn='event', operation=sum_op),
ttypes.Aggregation(inputColumn='event', operation=min_op),
ttypes.Aggregation(inputColumn="event", operation=sum_op),
ttypes.Aggregation(inputColumn="event", operation=min_op),
]
assert not group_by.contains_windowed_aggregation(aggregations)
aggregations.append(
ttypes.Aggregation(
inputColumn='event',
operation=sum_op,
windows=[ttypes.Window(length=7, timeUnit=days_unit)]
)
ttypes.Aggregation(inputColumn="event", operation=sum_op, windows=[ttypes.Window(length=7, timeUnit=days_unit)])
)
assert group_by.contains_windowed_aggregation(aggregations)

@@ -174,6 +156,7 @@ def test_validator_ok():
aggregations=None,
)


def test_validator_accuracy():
with pytest.raises(AssertionError, match="SNAPSHOT accuracy should not be specified for streaming sources"):
gb = group_by.GroupBy(
@@ -192,31 +175,21 @@ def test_validator_accuracy():
assert all([agg.inputColumn for agg in gb.aggregations if agg.operation != ttypes.Operation.COUNT])
group_by.validate_group_by(gb)


def test_generic_collector():
aggregation = group_by.Aggregation(
input_column="test", operation=group_by.Operation.APPROX_PERCENTILE([0.4, 0.2]))
aggregation = group_by.Aggregation(input_column="test", operation=group_by.Operation.APPROX_PERCENTILE([0.4, 0.2]))
assert aggregation.argMap == {"k": "128", "percentiles": "[0.4, 0.2]"}


def test_select_sanitization():
gb = group_by.GroupBy(
sources=[
ttypes.EventSource( # No selects are specified
table="event_table1",
query=query.Query(
selects=None,
time_column="ts"
)
table="event_table1", query=query.Query(selects=None, time_column="ts")
),
ttypes.EntitySource( # Some selects are specified
snapshotTable="entity_table1",
query=query.Query(
selects={
"key1": "key1_sql",
"event_id": "event_sql"
}
)
)
snapshotTable="entity_table1", query=query.Query(selects={"key1": "key1_sql", "event_id": "event_sql"})
),
],
keys=["key1", "key2"],
aggregations=group_by.Aggregations(
@@ -239,76 +212,100 @@ def test_snapshot_with_hour_aggregation():
ttypes.EntitySource( # Some selects are specified
snapshotTable="entity_table1",
query=query.Query(
selects={
"key1": "key1_sql",
"event_id": "event_sql"
},
selects={"key1": "key1_sql", "event_id": "event_sql"},
time_column="ts",
)
),
)
],
keys=["key1"],
aggregations=group_by.Aggregations(
random=ttypes.Aggregation(inputColumn="event_id", operation=ttypes.Operation.SUM, windows=[
ttypes.Window(1, ttypes.TimeUnit.HOURS),
]),
random=ttypes.Aggregation(
inputColumn="event_id",
operation=ttypes.Operation.SUM,
windows=[
ttypes.Window(1, ttypes.TimeUnit.HOURS),
],
),
),
backfill_start_date="2021-01-04",
)


def test_additional_metadata():
gb = group_by.GroupBy(
sources=[
ttypes.EventSource(
table="event_table1",
query=query.Query(
selects=None,
time_column="ts"
)
)
],
sources=[ttypes.EventSource(table="event_table1", query=query.Query(selects=None, time_column="ts"))],
keys=["key1", "key2"],
aggregations=[group_by.Aggregation(input_column="event_id", operation=ttypes.Operation.SUM)],
tags={"to_deprecate": True}
tags={"to_deprecate": True},
)
assert json.loads(gb.metaData.customJson)['groupby_tags']['to_deprecate']

assert json.loads(gb.metaData.customJson)["groupby_tags"]["to_deprecate"]


def test_group_by_with_description():
gb = group_by.GroupBy(
sources=[
ttypes.EventSource(
table="event_table1",
query=query.Query(
selects=None,
time_column="ts"
)
)
],
sources=[ttypes.EventSource(table="event_table1", query=query.Query(selects=None, time_column="ts"))],
keys=["key1", "key2"],
aggregations=[group_by.Aggregation(input_column="event_id", operation=ttypes.Operation.SUM)],
name="test.additional_metadata_gb",
description="GroupBy description"
description="GroupBy description",
)
assert gb.metaData.description == "GroupBy description"


def test_derivation():
derivation = Derivation(name="derivation_name", expression="derivation_expression")
expected_derivation = ttypes.Derivation(
name="derivation_name",
expression="derivation_expression")
expected_derivation = ttypes.Derivation(name="derivation_name", expression="derivation_expression")

assert derivation == expected_derivation


def test_derivation_with_description():
derivation = Derivation(name="derivation_name", expression="derivation_expression", description="Derivation description")
derivation = Derivation(
name="derivation_name", expression="derivation_expression", description="Derivation description"
)
expected_derivation = ttypes.Derivation(
name="derivation_name",
expression="derivation_expression",
metaData=ttypes.MetaData(description="Derivation description"))
metaData=ttypes.MetaData(description="Derivation description"),
)

assert derivation == expected_derivation
assert derivation == expected_derivation


def test_global_aggregation():
"""
Test global aggregations with empty keys
"""
# Test with keys=[]
gb = group_by.GroupBy(
sources=event_source("table"),
keys=[],
aggregations=group_by.Aggregations(
total_count=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.COUNT),
total_sum=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.SUM),
),
)
assert gb.keyColumns == []
assert len(gb.aggregations) == 2
group_by.validate_group_by(gb)

# Test with keys=None
gb = group_by.GroupBy(
sources=event_source("table"),
keys=None,
aggregations=group_by.Aggregations(
total_count=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.COUNT),
total_sum=ttypes.Aggregation(inputColumn="cnt", operation=ttypes.Operation.SUM),
),
)
assert gb.keyColumns == []
assert len(gb.aggregations) == 2
group_by.validate_group_by(gb)

# Test that global aggregations require aggregations
with pytest.raises(AssertionError, match="Global aggregations"):
fail_gb = group_by.GroupBy(
sources=event_source("table"),
keys=[],
aggregations=None,
)
49 changes: 49 additions & 0 deletions api/py/test/test_join.py
@@ -163,3 +163,52 @@ def test_derivation_with_description():
)

assert derivation == expected_derivation


def test_join_with_global_aggregation():
"""
Test that joins work with global aggregations (GroupBys with empty keys).
Global aggregations should join on system keys (partition/timestamp) only.
"""
# Create a global aggregation GroupBy (no keys)
global_gb = GroupBy(
sources=[event_source("global_stats_table")],
keys=[], # Empty keys = global aggregation
aggregations=[
api.Aggregation(inputColumn="event_id", operation=api.Operation.COUNT),
api.Aggregation(inputColumn="event_id", operation=api.Operation.SUM),
],
name="global_stats",
)

# Create a normal GroupBy with keys for comparison
regular_gb = GroupBy(
sources=[event_source("user_stats_table")],
keys=["subject"],
aggregations=[
api.Aggregation(inputColumn="event_id", operation=api.Operation.LAST),
],
name="user_stats",
)

# Create a join with both global and regular aggregations
join = Join(
left=event_source("events_table"),
right_parts=[
api.JoinPart(groupBy=global_gb, prefix="global"),
api.JoinPart(groupBy=regular_gb), # Uses default key mapping
],
name="events_with_global_and_user_stats",
)

# Verify the join was created successfully
assert join is not None
assert len(join.joinParts) == 2

# Verify global aggregation has empty keyColumns
assert join.joinParts[0].groupBy.keyColumns == []
assert len(join.joinParts[0].groupBy.aggregations) == 2

# Verify regular aggregation has keys
assert join.joinParts[1].groupBy.keyColumns == ["subject"]
assert len(join.joinParts[1].groupBy.aggregations) == 1
1 change: 1 addition & 0 deletions api/src/main/scala/ai/chronon/api/Constants.scala
@@ -67,4 +67,5 @@ object Constants {
val chrononArchiveFlag: String = "chronon_archived"
val ChainingRequestTs: String = "chaining_request_ts"
val ChainingFetchTs: String = "chaining_fetch_ts"
val GlobalAggregationKVStoreKey: String = "__global_aggregation_dummy_key__"
}
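The new constant above is presumably the sentinel under which global aggregates are stored and fetched online, since a GroupBy with no key columns has nothing to build a key from. A rough sketch of that idea (an assumption about the serving layer, not something shown in this diff):

# Assumption: when a GroupBy has no key columns, the serving layer keys the
# aggregate under a single sentinel value (mirroring GlobalAggregationKVStoreKey).
GLOBAL_AGGREGATION_KV_STORE_KEY = "__global_aggregation_dummy_key__"

def kv_store_key(key_columns, row):
    # Hypothetical helper: build the key-value store key for one row.
    if not key_columns:  # global aggregation -> one bucket for all data
        return GLOBAL_AGGREGATION_KV_STORE_KEY
    return "|".join(str(row[k]) for k in key_columns)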