Skip to content

Commit a546757

Browse files
authored
Add support for count_if in decomposing to measures (#1347)
* Add support for count_if, max_by, and min_by in decomposing to measures * Fix gradle wrapper * Only support count_if * Fix * Test * Fix error in countif post-agg * Fix test for countif post-agg
1 parent b2f5d0b commit a546757

File tree

4 files changed

+47
-9
lines changed

4 files changed

+47
-9
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,13 @@ jobs:
150150
run:
151151
working-directory: ./datajunction-clients/java
152152
steps:
153-
- uses: actions/checkout@v3
153+
- uses: actions/checkout@v4
154154
- name: Set up JDK 17
155-
uses: actions/setup-java@v3
155+
uses: actions/setup-java@v4
156156
with:
157157
java-version: '17'
158158
distribution: 'temurin'
159-
- name: Official Gradle Wrapper Validation Action
160-
uses: gradle/actions/wrapper-validation@v3
159+
- name: Setup Gradle
160+
uses: gradle/actions/setup-gradle@v4
161161
- name: Build with Gradle
162-
uses: gradle/actions/setup-gradle@v3
163-
with:
164-
arguments: build -x test
165-
build-root-directory: ./datajunction-clients/java
162+
run: ./gradlew build -x test

datajunction-server/datajunction_server/sql/decompose.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@ def __init__(self, query_ast: ast.Query):
2222
self.handlers = {
2323
dj_functions.Sum: self._simple_associative_agg,
2424
dj_functions.Count: self._simple_associative_agg,
25+
dj_functions.CountIf: self._simple_associative_agg,
2526
dj_functions.Max: self._simple_associative_agg,
27+
dj_functions.MaxBy: self._simple_associative_agg,
2628
dj_functions.Min: self._simple_associative_agg,
29+
dj_functions.MinBy: self._simple_associative_agg,
2730
dj_functions.Avg: self._avg,
2831
dj_functions.AnyValue: self._simple_associative_agg,
2932
}
@@ -158,7 +161,7 @@ def update_ast(func, measures: list[Measure]):
158161
),
159162
)
160163
elif (
161-
func.function() == dj_functions.Count
164+
func.function() in (dj_functions.Count, dj_functions.CountIf)
162165
and func.quantifier != ast.SetQuantifier.Distinct
163166
):
164167
func.name.name = "SUM"

datajunction-server/datajunction_server/sql/functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,8 @@ class CountIf(Function):
12221222
count_if(expr) - Returns the number of true values in expr.
12231223
"""
12241224

1225+
is_aggregation = True
1226+
12251227

12261228
@CountIf.register # type: ignore
12271229
def infer_type(arg: ct.BooleanType) -> ct.IntegerType:
@@ -3244,6 +3246,8 @@ class MaxBy(Function):
32443246
max_by(val, key) - Returns the value of val corresponding to the maximum value of key.
32453247
"""
32463248

3249+
is_aggregation = True
3250+
32473251

32483252
@MaxBy.register # type: ignore
32493253
def infer_type(val: ct.ColumnType, key: ct.ColumnType) -> ct.ColumnType:
@@ -3334,6 +3338,8 @@ class MinBy(Function):
33343338
min_by(val, key) - Returns the value of val corresponding to the minimum value of key.
33353339
"""
33363340

3341+
is_aggregation = True
3342+
33373343

33383344
@MinBy.register # type: ignore
33393345
def infer_type(val: ct.ColumnType, key: ct.ColumnType) -> ct.ColumnType:

datajunction-server/tests/sql/decompose_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,38 @@ def test_unsupported_aggregation_function():
618618
)
619619

620620

621+
def test_count_if():
622+
"""
623+
Test decomposition for count_if.
624+
"""
625+
extractor = MeasureExtractor.from_query_string(
626+
"SELECT CAST(COUNT_IF(ARRAY_CONTAINS(field_a, 'xyz')) AS FLOAT) / COUNT(*) "
627+
"FROM parent_node",
628+
)
629+
measures, derived_sql = extractor.extract()
630+
expected_measures = [
631+
Measure(
632+
name="field_a_count_if_c1f2ed10",
633+
expression="ARRAY_CONTAINS(field_a, 'xyz')",
634+
aggregation="COUNT_IF",
635+
rule=AggregationRule(type=Aggregability.FULL),
636+
),
637+
Measure(
638+
name="count_3389dae3",
639+
expression="*",
640+
aggregation="COUNT",
641+
rule=AggregationRule(type=Aggregability.FULL),
642+
),
643+
]
644+
assert measures == expected_measures
645+
assert str(derived_sql) == str(
646+
parse(
647+
"SELECT CAST(SUM(field_a_count_if_c1f2ed10) AS FLOAT) / SUM(count_3389dae3) "
648+
"FROM parent_node",
649+
),
650+
)
651+
652+
621653
def test_metric_query_with_aliases():
622654
"""
623655
Test behavior when the query contains unsupported aggregation functions. We just return an

0 commit comments

Comments
 (0)